From 38ada4aa4bad4b76556c39f21a832942b9a8f6e7 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Wed, 17 Apr 2024 07:11:34 -0700 Subject: [PATCH 001/165] Short preamble for the README, explaining why this clone exists --- README.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/README.md b/README.md index 69827e9dd747..1b8e0abbb939 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,16 @@ +# Triton-CPU + +A long-lived development branch to build an experimental CPU backend for [Triton](https://github.com/openai/triton). + +This repository clones the main Triton repository, but we intend to minimize +divergences in the core (and ideally upstream anything that needs to change and +isn't too CPU-specific). Most of the CPU work should be in a backend +subdirectory (similar to how GPU vendors are supported today). We're starting +with a clone to give ourselves maximum development flexibility as this project +gets off the ground! + +# Upstream README +
Triton logo
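The preamble above says most of the CPU work should live in a backend subdirectory, mirroring how the GPU vendors plug in today. Concretely that means a `third_party/cpu` plugin exposing a compiler backend and a runtime driver, which patch 003 below introduces. The following is a condensed sketch of those two entry points, not a definitive implementation; class names and signatures follow the files that patch adds, and the comments are editorial summaries of what the later patches fill in.

```
# Condensed from the files patch 003 adds under third_party/cpu/backend/.
# Only the plugin entry points are shown; the real classes also carry the
# ttir -> ttcir -> llir -> exe stage pipeline and the launcher plumbing.
from triton.backends.compiler import BaseBackend
from triton.backends.driver import CPUDriverBase


class CPUBackend(BaseBackend):
    # Compiler side: picked when the runtime reports a "cpu" target.
    @staticmethod
    def supports_target(target: tuple):
        return target[0] == "cpu"


class CPUDriver(CPUDriverBase):
    # Runtime side: always reports itself as active, so runtime/driver.py
    # prefers a GPU driver when one is present unless TRITON_CPU_BACKEND=1
    # forces the CPU path.
    @staticmethod
    def is_active():
        return True
```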
From 8600c20cc8ff4407ee6b841a78135965fec1ea7b Mon Sep 17 00:00:00 2001 From: Facebook Community Bot Date: Wed, 1 May 2024 12:39:20 -0700 Subject: [PATCH 002/165] OSS Automated Fix: Addition of Code of Conduct (#1) --- CODE_OF_CONDUCT.md | 80 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 CODE_OF_CONDUCT.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000000..3232ed665566 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. 
+Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq From 875b15fbb0f99a843800bb037b4519fe86020157 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Wed, 1 May 2024 17:39:53 -0700 Subject: [PATCH 003/165] [BACKEND][CPU] Initial plumbing for cpu backend (#2) * [BACKEND][CPU] Implement the empty cpu backend * Run clang-format * Fix yadf error Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- python/setup.py | 3 +- python/triton/backends/driver.py | 11 +++ python/triton/runtime/driver.py | 10 +++ third_party/cpu/CMakeLists.txt | 7 ++ third_party/cpu/backend/compiler.py | 106 +++++++++++++++++++++++++ third_party/cpu/backend/driver.py | 68 ++++++++++++++++ third_party/cpu/include/CMakeLists.txt | 0 third_party/cpu/lib/CMakeLists.txt | 0 third_party/cpu/triton_cpu.cc | 24 ++++++ 9 files changed, 228 insertions(+), 1 deletion(-) create mode 100644 third_party/cpu/CMakeLists.txt create mode 100644 third_party/cpu/backend/compiler.py create mode 100644 third_party/cpu/backend/driver.py create mode 100644 third_party/cpu/include/CMakeLists.txt create mode 100644 third_party/cpu/lib/CMakeLists.txt create mode 100644 third_party/cpu/triton_cpu.cc diff --git a/python/setup.py b/python/setup.py index 7be7a472ff9b..886693762e3c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -584,7 +584,8 @@ def build_extension(self, ext): url_func=lambda system, arch, version: f"https://developer.download.nvidia.com/compute/cuda/redist/cuda_cupti/{system}-{arch}/cuda_cupti-{system}-{arch}-{version}-archive.tar.xz", ) -backends = [*BackendInstaller.copy(["nvidia", "amd"]), *BackendInstaller.copy_externals()] + +backends = [*BackendInstaller.copy(["nvidia", "amd", "cpu"]), *BackendInstaller.copy_externals()] def add_link_to_backends(): diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index 6606b21ca8a2..72347735476b 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -51,3 +51,14 @@ def __init__(self): # TODO: remove once TMA is cleaned up def assemble_tensormap_to_arg(self, tensormaps_info, args): return args + + +class CPUDriverBase(DriverBase): + + def __init__(self): + # Right now, we just provide dummy functions. + # TODO: Consider better engineering the code only intended for GPU in jit.py. 
+ self.get_device_capability = lambda idx: (0, 0) + self.get_current_stream = lambda idx: 0 + self.get_current_device = lambda: 0 + self.set_current_device = lambda idx: None diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index c3b97a764145..4cf1aea8e494 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -1,9 +1,19 @@ +import os + from ..backends import backends from ..backends import DriverBase def _create_driver(): + if os.getenv("TRITON_CPU_BACKEND", "0") == "1": + if "cpu" not in backends: + raise RuntimeError("TRITON_CPU_BACKEND is set, but CPU backend is unavailable.") + return backends["cpu"].driver() + actives = [x.driver for x in backends.values() if x.driver.is_active()] + if len(actives) >= 2 and backends["cpu"].driver.is_active(): + print("Both CPU and GPU backends are available. Using the GPU backend.") + actives.remove(backends["cpu"].driver) if len(actives) != 1: raise RuntimeError(f"{len(actives)} active drivers ({actives}). There should only be one.") return actives[0]() diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt new file mode 100644 index 000000000000..c1f3dd476784 --- /dev/null +++ b/third_party/cpu/CMakeLists.txt @@ -0,0 +1,7 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) +if(TRITON_BUILD_PYTHON_MODULE) + add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc) +endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py new file mode 100644 index 000000000000..e888dde3f4f1 --- /dev/null +++ b/third_party/cpu/backend/compiler.py @@ -0,0 +1,106 @@ +import functools +import hashlib +import re + +from dataclasses import dataclass +from typing import Any + +from triton._C.libtriton import cpu, ir, passes +from triton.backends.compiler import BaseBackend + + +@dataclass(frozen=True) +class CPUOptions: + # GPU-specific options are used in several places. + # For now, we just provide dummy values. + num_warps: int = 0 + num_stages: int = 0 + num_ctas: int = 0 + cluster_dims: tuple = (1, 1, 1) + debug: bool = False + + # TODO: We may introduce CPU-specific options like # of cores. + + def __post_init__(self): + pass + + def hash(self): + hash_dict = dict(self.__dict__) + key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) + return hashlib.sha256(key.encode("utf-8")).hexdigest() + + +class CPUBackend(BaseBackend): + + @staticmethod + def supports_target(target: tuple): + return target[0] == "cpu" + + def __init__(self, target: tuple) -> None: + super().__init__(target) + self.binary_ext = "exe" + + def parse_options(self, opts) -> Any: + args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} + return CPUOptions(**args) + + def pack_metadata(self, metadata): + return metadata + + def get_codegen_implementation(self): + codegen_fns = dict() + return codegen_fns + + def load_dialects(self, ctx): + cpu.load_dialects(ctx) + + @staticmethod + def make_ttir(mod, metadata, opt): + # This is the same as the Nvidia backend. 
+ pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) + passes.ttir.add_combine(pm) + passes.common.add_canonicalizer(pm) + passes.ttir.add_reorder_broadcast(pm) + passes.common.add_cse(pm) + passes.common.add_licm(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) + return mod + + @staticmethod + def make_ttcir(mod, metadata, opt): + # TODO: + return mod + + @staticmethod + def make_llir(src, metadata, options): + # TODO: + metadata["shared"] = 0 + return src + + @staticmethod + def make_exe(src, metadata, options): + # Right now, src is just TTIR. Extract kernel name from tt.func. + names = re.findall(r"\s+tt.func public @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) + assert len(names) == 1 + metadata["name"] = names[0] + + # TODO: Call llc to create an executable. + return src + + def add_stages(self, stages, options): + stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) + stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) + stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options) + + @functools.lru_cache() + def hash(self): + # TODO: Get more detailed CPU info like raw brand name with supported ISAs. + # Right now it would only return a simple string like "x86_64" or "aarch64". + import platform + + return f"{platform.machine()}" diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py new file mode 100644 index 000000000000..f795521fe204 --- /dev/null +++ b/third_party/cpu/backend/driver.py @@ -0,0 +1,68 @@ +from triton.backends.driver import CPUDriverBase + +# ------------------------ +# Utils +# ------------------------ + + +class CPUUtils(object): + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(CPUUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + pass + + @staticmethod + def get_device_properties(device): + # This is just dummy for now. We will need to implement driver.c. + return { + "max_shared_mem": 0, + "multiprocessor_count": 0, + "sm_clock_rate": 0, + "mem_clock_rate": 0, + "mem_bus_width": 0, + } + + @staticmethod + def load_binary(name, kernel_asm, shared, device): + # This is just dummy for now. We will need to implement driver.c. + return (None, kernel_asm, 0, 0) + + +# ------------------------ +# Launcher +# ------------------------ + + +def make_launcher(constants, signature, ids): + pass + + +class CPULauncher(object): + + def __init__(self, src, metadata): + # TODO: + self.launch = lambda *args, **kwargs: None + + def __call__(self, *args, **kwargs): + print("CPULauncher.__call__") + self.launch(*args, **kwargs) + + +class CPUDriver(CPUDriverBase): + + def __init__(self): + self.utils = CPUUtils() + self.launcher_cls = CPULauncher + super().__init__() + + def get_current_target(self): + # Capability and warp size are zeros for CPU. 
+ return ("cpu", 0, 0) + + @staticmethod + def is_active(): + return True diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc new file mode 100644 index 000000000000..1ccfb19f1526 --- /dev/null +++ b/third_party/cpu/triton_cpu.cc @@ -0,0 +1,24 @@ +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/IR/Constants.h" +#include "llvm/Support/TargetSelect.h" +#include +#include +#include + +#include + +namespace py = pybind11; + +void init_triton_passes_ttcpuir(py::module &&m) { + // TODO: +} + +void init_triton_cpu(py::module &&m) { + auto passes = m.def_submodule("passes"); + init_triton_passes_ttcpuir(passes.def_submodule("ttcpuir")); + + m.def("load_dialects", [](mlir::MLIRContext &context) { + // TODO: + }); +} From 015e50df2969065c034391ce79ca334a8c203d43 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 6 May 2024 11:00:50 -0700 Subject: [PATCH 004/165] [BACKEND][CPU] Create TritonCPU and conversion dialects (#3) --- bin/RegisterTritonDialects.h | 2 + include/triton/Analysis/Utility.h | 1 + include/triton/Conversion/CMakeLists.txt | 2 + .../Conversion/TritonCPUToLLVM/CMakeLists.txt | 3 + .../Conversion/TritonCPUToLLVM/Passes.h | 29 +++++ .../Conversion/TritonCPUToLLVM/Passes.td | 25 ++++ .../TritonCPUToLLVM/TypeConverter.h | 20 ++++ .../TritonToTritonCPU/CMakeLists.txt | 3 + .../Conversion/TritonToTritonCPU/Passes.h | 15 +++ .../Conversion/TritonToTritonCPU/Passes.td | 23 ++++ .../TritonToTritonCPU/TritonToTritonCPUPass.h | 18 +++ .../TritonToTritonGPU/CMakeLists.txt | 2 +- .../Conversion/TritonToTritonGPU/Passes.h | 4 +- .../Conversion/TritonToTritonGPU/Passes.td | 4 +- include/triton/Dialect/CMakeLists.txt | 1 + .../triton/Dialect/TritonCPU/CMakeLists.txt | 2 + .../triton/Dialect/TritonCPU/IR/Attributes.h | 9 ++ .../Dialect/TritonCPU/IR/CMakeLists.txt | 21 ++++ include/triton/Dialect/TritonCPU/IR/Dialect.h | 17 +++ .../Dialect/TritonCPU/IR/TritonCPUAttrDefs.td | 25 ++++ .../Dialect/TritonCPU/IR/TritonCPUDialect.td | 29 +++++ .../TritonCPU/IR/TritonCPUInterfaces.h | 6 + .../Dialect/TritonCPU/IR/TritonCPUOps.td | 12 ++ .../Dialect/TritonCPU/IR/TritonCPUTypes.td | 26 +++++ include/triton/Dialect/TritonCPU/IR/Types.h | 10 ++ .../TritonCPU/Transforms/CMakeLists.txt | 3 + .../Dialect/TritonCPU/Transforms/Passes.h | 16 +++ .../Dialect/TritonCPU/Transforms/Passes.td | 6 + .../Transforms/TritonCPUConversion.h | 31 +++++ lib/Conversion/CMakeLists.txt | 2 + lib/Conversion/TritonCPUToLLVM/CMakeLists.txt | 15 +++ .../TritonCPUToLLVM/TritonCPUToLLVM.cpp | 88 ++++++++++++++ .../TritonCPUToLLVM/TypeConverter.cpp | 16 +++ .../TritonToTritonCPU/CMakeLists.txt | 15 +++ .../TritonToTritonCPU/TritonCPUConversion.cpp | 108 ++++++++++++++++++ .../TritonToTritonCPUPass.cpp | 41 +++++++ .../TritonToTritonGPU/CMakeLists.txt | 2 +- lib/Dialect/CMakeLists.txt | 1 + lib/Dialect/TritonCPU/CMakeLists.txt | 2 + lib/Dialect/TritonCPU/IR/CMakeLists.txt | 11 ++ lib/Dialect/TritonCPU/IR/Dialect.cpp | 42 +++++++ lib/Dialect/TritonCPU/IR/Types.cpp | 38 ++++++ .../TritonCPU/Transforms/CMakeLists.txt | 13 +++ python/src/passes.cc | 6 + third_party/cpu/CMakeLists.txt | 6 +- third_party/cpu/backend/compiler.py | 58 +++++++++- third_party/cpu/backend/driver.py | 1 - 
third_party/cpu/include/CMakeLists.txt | 0 third_party/cpu/lib/CMakeLists.txt | 0 third_party/cpu/triton_cpu.cc | 16 ++- 50 files changed, 828 insertions(+), 18 deletions(-) create mode 100644 include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt create mode 100644 include/triton/Conversion/TritonCPUToLLVM/Passes.h create mode 100644 include/triton/Conversion/TritonCPUToLLVM/Passes.td create mode 100644 include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h create mode 100644 include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt create mode 100644 include/triton/Conversion/TritonToTritonCPU/Passes.h create mode 100644 include/triton/Conversion/TritonToTritonCPU/Passes.td create mode 100644 include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h create mode 100644 include/triton/Dialect/TritonCPU/CMakeLists.txt create mode 100644 include/triton/Dialect/TritonCPU/IR/Attributes.h create mode 100644 include/triton/Dialect/TritonCPU/IR/CMakeLists.txt create mode 100644 include/triton/Dialect/TritonCPU/IR/Dialect.h create mode 100644 include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td create mode 100644 include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td create mode 100644 include/triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h create mode 100644 include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td create mode 100644 include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td create mode 100644 include/triton/Dialect/TritonCPU/IR/Types.h create mode 100644 include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt create mode 100644 include/triton/Dialect/TritonCPU/Transforms/Passes.h create mode 100644 include/triton/Dialect/TritonCPU/Transforms/Passes.td create mode 100644 include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h create mode 100644 lib/Conversion/TritonCPUToLLVM/CMakeLists.txt create mode 100644 lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp create mode 100644 lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp create mode 100644 lib/Conversion/TritonToTritonCPU/CMakeLists.txt create mode 100644 lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp create mode 100644 lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp create mode 100644 lib/Dialect/TritonCPU/CMakeLists.txt create mode 100644 lib/Dialect/TritonCPU/IR/CMakeLists.txt create mode 100644 lib/Dialect/TritonCPU/IR/Dialect.cpp create mode 100644 lib/Dialect/TritonCPU/IR/Types.cpp create mode 100644 lib/Dialect/TritonCPU/Transforms/CMakeLists.txt delete mode 100644 third_party/cpu/include/CMakeLists.txt delete mode 100644 third_party/cpu/lib/CMakeLists.txt diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index ebd32ad4422f..4681acf8abf8 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -4,6 +4,7 @@ #include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h" #include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -71,6 +72,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // TODO: register Triton & TritonGPU passes registry .insert + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +#define GEN_PASS_DECL +#include "triton/Conversion/TritonCPUToLLVM/Passes.h.inc" + +std::unique_ptr> createConvertTritonCPUToLLVMPass(); + +#define 
GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonCPUToLLVM/Passes.h.inc" + +} // namespace triton + +} // namespace mlir + +#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/Passes.td b/include/triton/Conversion/TritonCPUToLLVM/Passes.td new file mode 100644 index 000000000000..a0bfd65c3d28 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/Passes.td @@ -0,0 +1,25 @@ +#ifndef TRITONCPU_CONVERSION_PASSES +#define TRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonCPUToLLVM : Pass<"convert-triton-cpu-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert TritonCPU to LLVM"; + let description = [{ + + }]; + let constructor = "mlir::triton::createConvertTritonCPUToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::LLVM::LLVMDialect", + "mlir::math::MathDialect", + "mlir::scf::SCFDialect", + "mlir::tensor::TensorDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + ]; +} + +#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h b/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h new file mode 100644 index 000000000000..57f4ce78f091 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h @@ -0,0 +1,20 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_TYPECONVERTER_H +#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/TritonCPU/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); +}; + +#endif diff --git a/include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt b/include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..66945e2242f1 --- /dev/null +++ b/include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) +add_public_tablegen_target(TritonConversionToCPUPassIncGen) diff --git a/include/triton/Conversion/TritonToTritonCPU/Passes.h b/include/triton/Conversion/TritonToTritonCPU/Passes.h new file mode 100644 index 000000000000..4ec0411da1ab --- /dev/null +++ b/include/triton/Conversion/TritonToTritonCPU/Passes.h @@ -0,0 +1,15 @@ +#ifndef TRITON_CONVERSION_TO_CPU_PASSES_H +#define TRITON_CONVERSION_TO_CPU_PASSES_H + +#include "triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h" + +namespace mlir { +namespace triton { + +#define GEN_PASS_REGISTRATION +#include "triton/Conversion/TritonToTritonCPU/Passes.h.inc" + +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton/Conversion/TritonToTritonCPU/Passes.td b/include/triton/Conversion/TritonToTritonCPU/Passes.td new file mode 100644 index 000000000000..a15bd15bfcd1 --- /dev/null +++ b/include/triton/Conversion/TritonToTritonCPU/Passes.td @@ -0,0 +1,23 @@ +#ifndef TRITON_CONVERSION_TO_CPU_PASSES +#define TRITON_CONVERSION_TO_CPU_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertTritonToTritonCPU: Pass<"convert-triton-to-tritoncpu", "mlir::ModuleOp"> { + let summary = "Convert Triton to TritonCPU"; + let description = [{ + 
+ }]; + let constructor = "mlir::triton::createConvertTritonToTritonCPUPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::scf::SCFDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; + + let options = [ + ]; +} + +#endif diff --git a/include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h b/include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h new file mode 100644 index 000000000000..2e7acbd24548 --- /dev/null +++ b/include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h @@ -0,0 +1,18 @@ +#ifndef TRITON_CONVERSION_TRITONTOTRITONCPU_TRITONTOTRITONCPUPASS_H +#define TRITON_CONVERSION_TRITONTOTRITONCPU_TRITONTOTRITONCPUPASS_H + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { + +std::unique_ptr> createConvertTritonToTritonCPUPass(); + +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt b/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt index 99d90c4d75e6..51ad71b4c2f8 100644 --- a/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/include/triton/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonGPU) -add_public_tablegen_target(TritonConversionPassIncGen) +add_public_tablegen_target(TritonConversionToGPUPassIncGen) diff --git a/include/triton/Conversion/TritonToTritonGPU/Passes.h b/include/triton/Conversion/TritonToTritonGPU/Passes.h index e159406b3ed4..112269bfb369 100644 --- a/include/triton/Conversion/TritonToTritonGPU/Passes.h +++ b/include/triton/Conversion/TritonToTritonGPU/Passes.h @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_PASSES_H -#define TRITON_CONVERSION_PASSES_H +#ifndef TRITON_CONVERSION_TO_GPU_PASSES_H +#define TRITON_CONVERSION_TO_GPU_PASSES_H #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" diff --git a/include/triton/Conversion/TritonToTritonGPU/Passes.td b/include/triton/Conversion/TritonToTritonGPU/Passes.td index f20c3604090e..81dc45a9ae59 100644 --- a/include/triton/Conversion/TritonToTritonGPU/Passes.td +++ b/include/triton/Conversion/TritonToTritonGPU/Passes.td @@ -1,5 +1,5 @@ -#ifndef TRITON_CONVERSION_PASSES -#define TRITON_CONVERSION_PASSES +#ifndef TRITON_CONVERSION_TO_GPU_PASSES +#define TRITON_CONVERSION_TO_GPU_PASSES include "mlir/Pass/PassBase.td" diff --git a/include/triton/Dialect/CMakeLists.txt b/include/triton/Dialect/CMakeLists.txt index 6ef40db00f52..c964bdcea534 100644 --- a/include/triton/Dialect/CMakeLists.txt +++ b/include/triton/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Triton) +add_subdirectory(TritonCPU) add_subdirectory(TritonGPU) add_subdirectory(TritonNvidiaGPU) diff --git a/include/triton/Dialect/TritonCPU/CMakeLists.txt b/include/triton/Dialect/TritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..9f57627c321f --- /dev/null +++ b/include/triton/Dialect/TritonCPU/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/include/triton/Dialect/TritonCPU/IR/Attributes.h b/include/triton/Dialect/TritonCPU/IR/Attributes.h new file mode 100644 index 000000000000..7d4b98019d50 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/Attributes.h @@ -0,0 +1,9 @@ +#ifndef TRITON_DIALECT_TRITONCPU_IR_ATTRIBUTES_H_ +#define TRITON_DIALECT_TRITONCPU_IR_ATTRIBUTES_H_ + +#include 
"triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h" + +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.h.inc" + +#endif // TRITON_DIALECT_TRITONCPU_IR_ATTRIBUTES_H_ diff --git a/include/triton/Dialect/TritonCPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonCPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..ace7d4ee7439 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/CMakeLists.txt @@ -0,0 +1,21 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + +set(LLVM_TARGET_DEFINITIONS TritonCPUOps.td) +mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_cpu) +mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_cpu) +mlir_tablegen(Ops.h.inc -gen-op-decls) +mlir_tablegen(Ops.cpp.inc -gen-op-defs) +mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_cpu) +mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_cpu) +add_mlir_doc(TritonCPUDialect TritonCPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonCPUOps TritonCPUOps dialects/ -gen-op-doc) +add_public_tablegen_target(TritonCPUTableGen) + +set(LLVM_TARGET_DEFINITIONS TritonCPUAttrDefs.td) +mlir_tablegen(TritonCPUAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(TritonCPUAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(TritonCPUAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TritonCPUAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(TritonCPUAttrDefsIncGen) diff --git a/include/triton/Dialect/TritonCPU/IR/Dialect.h b/include/triton/Dialect/TritonCPU/IR/Dialect.h new file mode 100644 index 000000000000..e8e8de322bb4 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/Dialect.h @@ -0,0 +1,17 @@ +#ifndef TRITON_DIALECT_TRITONCPU_IR_DIALECT_H_ +#define TRITON_DIALECT_TRITONCPU_IR_DIALECT_H_ + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" + +// TritonCPU depends on Triton +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Attributes.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h.inc" +#include "triton/Dialect/TritonCPU/IR/Types.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.h.inc" + +#endif // TRITON_DIALECT_TRITONCPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td new file mode 100644 index 000000000000..df933dd49511 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td @@ -0,0 +1,25 @@ +#ifndef TRITONCPU_ATTRDEFS +#define TRITONCPU_ATTRDEFS + +include "mlir/IR/AttrTypeBase.td" +include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" + +//===----------------------------------------------------------------------===// +// TritonCPU Attribute Definitions +//===----------------------------------------------------------------------===// +def TritonCPU_AttrTrait : AttrInterface<"TritonCPU_AttrTrait"> { + let cppNamespace = "::mlir::triton::cpu"; +} + +class TritonCPU_Attr traits = [], + Dialect dialect = TritonCPU_Dialect, + string baseCppClass = "::mlir::Attribute"> + : AttrDef { + + let description = [{ + WIP... 
+ }]; +} + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td new file mode 100644 index 000000000000..9ccac13f0b58 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td @@ -0,0 +1,29 @@ +#ifndef TRITONCPU_DIALECT +#define TRITONCPU_DIALECT + +include "mlir/IR/OpBase.td" + +def TritonCPU_Dialect : Dialect { + let name = "triton_cpu"; + + let cppNamespace = "::mlir::triton::cpu"; + + let hasOperationAttrVerify = 1; + + let description = [{ + Triton CPU Dialect. + }]; + + let dependentDialects = [ + "triton::TritonDialect", + "tensor::TensorDialect", + ]; + + let extraClassDeclaration = [{ + void registerTypes(); + }]; + + let useDefaultTypePrinterParser = 1; +} + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h b/include/triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h new file mode 100644 index 000000000000..de27597a76ef --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUInterfaces.h @@ -0,0 +1,6 @@ +#ifndef TRITON_CPU_DIALECT_INTERFACES_H +#define TRITON_CPU_DIALECT_INTERFACES_H + +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrInterfaces.h.inc" + +#endif // TRITON_CPU_DIALECT_INTERFACES_H diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td new file mode 100644 index 000000000000..16d9e433e899 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -0,0 +1,12 @@ +#ifndef TRITONCPU_OPS +#define TRITONCPU_OPS + +include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" +include "triton/Dialect/TritonCPU/IR/TritonCPUTypes.td" +include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" +include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "mlir/IR/OpBase.td" + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td new file mode 100644 index 000000000000..ea31f877dab3 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td @@ -0,0 +1,26 @@ +#ifndef TRITONCPU_TYPES +#define TRITONCPU_TYPES + +include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" +include "mlir/IR/AttrTypeBase.td" + +class TTC_TypeDef traits = []> + : TypeDef { + let mnemonic = _mnemonic; +} + +def TTC_TokenType : TTC_TypeDef<"Token", "token"> { + let parameters = (ins "int32_t":$type); + + let builders = [ + TypeBuilder<(ins "unsigned":$type), [{ + return $_get($_ctxt, type); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + + let skipDefaultBuilders = 1; +} + +#endif diff --git a/include/triton/Dialect/TritonCPU/IR/Types.h b/include/triton/Dialect/TritonCPU/IR/Types.h new file mode 100644 index 000000000000..e8c984628aa5 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/IR/Types.h @@ -0,0 +1,10 @@ +#ifndef TRITONCPU_IR_TYPES_H_ +#define TRITONCPU_IR_TYPES_H_ + +#include "mlir/IR/TypeSupport.h" +#include "mlir/IR/Types.h" + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/Types.h.inc" + +#endif // TRITON_IR_TYPES_H_ diff --git a/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt b/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..6aa946f64932 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc 
-gen-pass-decls -name TritonCPU) +add_public_tablegen_target(TritonCPUTransformsIncGen) diff --git a/include/triton/Dialect/TritonCPU/Transforms/Passes.h b/include/triton/Dialect/TritonCPU/Transforms/Passes.h new file mode 100644 index 000000000000..f31e47317080 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/Transforms/Passes.h @@ -0,0 +1,16 @@ +#ifndef TRITON_DIALECT_TRITONCPU_TRANSFORMS_PASSES_H_ +#define TRITON_DIALECT_TRITONCPU_TRANSFORMS_PASSES_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace triton { +namespace cpu {} // namespace cpu +} // namespace triton + +/// Generate the code for registering passes. +#define GEN_PASS_REGISTRATION +#include "triton/Dialect/TritonCPU/Transforms/Passes.h.inc" + +} // namespace mlir +#endif diff --git a/include/triton/Dialect/TritonCPU/Transforms/Passes.td b/include/triton/Dialect/TritonCPU/Transforms/Passes.td new file mode 100644 index 000000000000..a1d5271ee6e7 --- /dev/null +++ b/include/triton/Dialect/TritonCPU/Transforms/Passes.td @@ -0,0 +1,6 @@ +#ifndef TRITONCPU_PASSES +#define TRITONCPU_PASSES + +include "mlir/Pass/PassBase.td" + +#endif diff --git a/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h b/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h new file mode 100644 index 000000000000..01c24e19c60e --- /dev/null +++ b/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h @@ -0,0 +1,31 @@ +//===----------------------------------------------------------------------===// +// +// Defines utilities to use while converting to the TritonCPU dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ +#define TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class TritonCPUTypeConverter : public TypeConverter { +public: + TritonCPUTypeConverter(MLIRContext *context); + +private: + MLIRContext *context; +}; + +class TritonCPUConversionTarget : public ConversionTarget { + +public: + explicit TritonCPUConversionTarget(MLIRContext &ctx, + TritonCPUTypeConverter &typeConverter); +}; + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 143a4375a811..5c3aa2c1a827 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,2 +1,4 @@ +add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) +add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..0dfa7cb5ca18 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonCPUToLLVM + TypeConverter.cpp + TritonCPUToLLVM.cpp + + DEPENDS + TritonCPUConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + TritonAnalysis + TritonIR + TritonCPUIR + TritonCPUTransforms +) diff --git a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp new file mode 100644 index 000000000000..3646c92d80f9 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp @@ -0,0 +1,88 @@ +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" 
+#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonCPUToLLVM/Passes.h" +#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTTRITONCPUTOLLVM +#include "triton/Conversion/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; + +namespace { + +class TritonLLVMFunctionConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addIllegalDialect(); + addIllegalDialect(); + addLegalOp(); + } +}; + +struct ConvertTritonCPUToLLVM + : public triton::impl::ConvertTritonCPUToLLVMBase { + using ConvertTritonCPUToLLVMBase< + ConvertTritonCPUToLLVM>::ConvertTritonCPUToLLVMBase; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + ConvertTritonCPUToLLVM() : ConvertTritonCPUToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + mlir::LowerToLLVMOptions option(context); + option.overrideIndexBitwidth(32); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + // TODO: + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { + +std::unique_ptr> createConvertTritonCPUToLLVMPass() { + return std::make_unique(); +} + +} // namespace triton +} // namespace mlir diff --git a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000000..6a5ac668dd52 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp @@ -0,0 +1,16 @@ +#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" + +using namespace mlir; +using namespace mlir::triton; + +TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + // Internally store bfloat16 as int16 + addConversion([&](BFloat16Type type) -> std::optional { + return IntegerType::get(type.getContext(), 16); + }); +} diff --git a/lib/Conversion/TritonToTritonCPU/CMakeLists.txt b/lib/Conversion/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..f1b612b9c291 --- /dev/null +++ b/lib/Conversion/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonToTritonCPU + TritonCPUConversion.cpp + 
TritonToTritonCPUPass.cpp + + DEPENDS + TritonConversionToCPUPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTransforms + TritonIR + TritonCPUIR + TritonCPUTransforms +) diff --git a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp new file mode 100644 index 000000000000..dabc2a27a87b --- /dev/null +++ b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp @@ -0,0 +1,108 @@ +#include "triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h" + +#include "mlir/IR/IRMapping.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include +#include + +using namespace mlir; +using namespace mlir::triton::cpu; + +// +// TypeConverter +// +TritonCPUTypeConverter::TritonCPUTypeConverter(MLIRContext *context) + : context(context) { + addConversion([](Type type) { return type; }); + + // Add encoding for tensor + addConversion([this](RankedTensorType tensorType) -> RankedTensorType { + // TODO: + return tensorType; + }); + + // Add encoding for tensor pointer + addConversion([this](triton::PointerType ptrType) -> triton::PointerType { + // Check whether tensor pointer `tt.ptr>` + auto pointeeTensorType = + ptrType.getPointeeType().dyn_cast(); + if (pointeeTensorType == nullptr) + return ptrType; + + // Add layout into the tensor + auto convertedTensorType = convertType(pointeeTensorType); + return triton::PointerType::get(convertedTensorType, + ptrType.getAddressSpace()); + }); + + // + // Materializations + // + // This will be called when (newArgType != origArgType) + // This will create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonCPU conversion"); + return std::nullopt; + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, + Location loc) -> std::optional { + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonCPU Conversion"); + return std::nullopt; + }); + + // This will be called when (desiredType != newOperandType) + // where, desiredType = typeConverter->convertType(origType) + // NOTE: only for remapped values. 
+ addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, Location loc) { + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonCPU Conversion"); + return std::nullopt; + }); +} + +// +// TritonCPUConversion +// +TritonCPUConversionTarget::TritonCPUConversionTarget( + MLIRContext &context, TritonCPUTypeConverter &typeConverter) + : ConversionTarget(context) { + // TODO: we should also verify ops of TritonCPUDialect + addLegalDialect(); + + // Some ops from SCF are illegal + addIllegalOp(); + + addDynamicallyLegalDialect([&](Operation *op) { + bool hasLegalRegions = true; + for (auto ®ion : op->getRegions()) { + hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); + } + if (hasLegalRegions && typeConverter.isLegal(op)) { + return true; + } + return false; + }); + + // We have requirements for the data layouts + addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { + Attribute aEncoding = + dotOp.getA().getType().cast().getEncoding(); + Attribute bEncoding = + dotOp.getB().getType().cast().getEncoding(); + // TODO: + return false; + }); +} diff --git a/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp b/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp new file mode 100644 index 000000000000..44c41636a3f3 --- /dev/null +++ b/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp @@ -0,0 +1,41 @@ +#include "triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h" +#include "llvm/ADT/APSInt.h" +#include + +#define GEN_PASS_CLASSES +#include "triton/Conversion/TritonToTritonCPU/Passes.h.inc" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +class ConvertTritonToTritonCPU + : public ConvertTritonToTritonCPUBase { +public: + ConvertTritonToTritonCPU() = default; + + void runOnOperation() override { + // TODO: + } +}; + +} // namespace + +std::unique_ptr> +mlir::triton::createConvertTritonToTritonCPUPass() { + return std::make_unique<::ConvertTritonToTritonCPU>(); +} diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index fb5f7156f9aa..02f3ce2157da 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -3,7 +3,7 @@ add_triton_library(TritonToTritonGPU TritonToTritonGPUPass.cpp DEPENDS - TritonConversionPassIncGen + TritonConversionToGPUPassIncGen LINK_LIBS PUBLIC MLIRIR diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 6ef40db00f52..c964bdcea534 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Triton) +add_subdirectory(TritonCPU) add_subdirectory(TritonGPU) add_subdirectory(TritonNvidiaGPU) diff --git a/lib/Dialect/TritonCPU/CMakeLists.txt b/lib/Dialect/TritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..9f57627c321f --- /dev/null +++ b/lib/Dialect/TritonCPU/CMakeLists.txt @@ -0,0 +1,2 @@ 
+add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/lib/Dialect/TritonCPU/IR/CMakeLists.txt b/lib/Dialect/TritonCPU/IR/CMakeLists.txt new file mode 100644 index 000000000000..67bf1bb1b9d4 --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/CMakeLists.txt @@ -0,0 +1,11 @@ +add_triton_library(TritonCPUIR + Dialect.cpp + Types.cpp + + DEPENDS + TritonCPUTableGen + TritonCPUAttrDefsIncGen + + LINK_LIBS PUBLIC + TritonIR +) diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp new file mode 100644 index 000000000000..e28a65358dca --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -0,0 +1,42 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + +#include + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::triton::cpu; + +//===----------------------------------------------------------------------===// +// Attribute methods +//===----------------------------------------------------------------------===// +#define GET_ATTRDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" + +void TritonCPUDialect::initialize() { + registerTypes(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" +#include "triton/Dialect/TritonCPU/IR/OpsEnums.cpp.inc" + >(); +} + +// verify TritonCPU ops +LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // TODO: fill this. 
+ return success(); +} diff --git a/lib/Dialect/TritonCPU/IR/Types.cpp b/lib/Dialect/TritonCPU/IR/Types.cpp new file mode 100644 index 000000000000..b6a17786bac2 --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/Types.cpp @@ -0,0 +1,38 @@ +#include "triton/Dialect/TritonCPU/IR/Types.h" +#include "mlir/IR/DialectImplementation.h" // required by `Types.cpp.inc` +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" // required by `Types.cpp.inc` + +using namespace mlir; +using namespace mlir::triton::cpu; + +#define GET_TYPEDEF_CLASSES +#include "triton/Dialect/TritonCPU/IR/Types.cpp.inc" + +Type TokenType::parse(AsmParser &parser) { + if (parser.parseLess()) + return Type(); + + int type = 1; + if (parser.parseInteger(type)) + return Type(); + + if (parser.parseGreater()) + return Type(); + + return TokenType::get(parser.getContext(), type); +} + +void TokenType::print(AsmPrinter &printer) const { + printer << "<" << getType() << ">"; +} + +//===----------------------------------------------------------------------===// +// Triton Dialect +//===----------------------------------------------------------------------===// +void ::mlir::triton::cpu::TritonCPUDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "triton/Dialect/TritonCPU/IR/Types.cpp.inc" + >(); +} diff --git a/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt new file mode 100644 index 000000000000..1714215b9434 --- /dev/null +++ b/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(TritonCPUTransforms + + DEPENDS + TritonCPUTransformsIncGen + + LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonAnalysis + TritonIR + TritonCPUIR + MLIRTransformUtils +) diff --git a/python/src/passes.cc b/python/src/passes.cc index 619ece2e3455..c365aaf43589 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -6,6 +6,7 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" +#include "triton/Conversion/TritonToTritonCPU/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -44,6 +45,8 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); + ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", + createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { @@ -78,6 +81,8 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUCoalesceAsyncCopy); } +void init_triton_passes_ttcpuir(py::module &&m) {} + void init_triton_passes_convert(py::module &&m) { using namespace mlir; ADD_PASS_WRAPPER_0("add_scf_to_cf", createConvertSCFToCFPass); @@ -96,6 +101,7 @@ void init_triton_passes(py::module &&m) { init_triton_passes_common(m.def_submodule("common")); init_triton_passes_convert(m.def_submodule("convert")); init_triton_passes_ttir(m.def_submodule("ttir")); + init_triton_passes_ttcpuir(m.def_submodule("ttcpuir")); init_triton_passes_ttgpuir(m.def_submodule("ttgpuir")); init_triton_passes_llvmir(m.def_submodule("llvmir")); } diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index c1f3dd476784..683889547b0a 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,7 +1,3 @@ 
-include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) -add_subdirectory(include) -add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc) + add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index e888dde3f4f1..47b44ca84cdc 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -1,11 +1,12 @@ import functools import hashlib +import os import re from dataclasses import dataclass from typing import Any -from triton._C.libtriton import cpu, ir, passes +from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend @@ -17,6 +18,7 @@ class CPUOptions: num_stages: int = 0 num_ctas: int = 0 cluster_dims: tuple = (1, 1, 1) + extern_libs: dict = None debug: bool = False # TODO: We may introduce CPU-specific options like # of cores. @@ -72,14 +74,66 @@ def make_ttir(mod, metadata, opt): @staticmethod def make_ttcir(mod, metadata, opt): + # TTIR -> TTCIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.ttir.add_convert_to_ttcpuir(pm) + + # # TODO: + # + + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + pm.run(mod) return mod @staticmethod def make_llir(src, metadata, options): + mod = src + # TritonCPU -> LLVM-IR (MLIR) + pm = ir.pass_manager(mod.context) + pm.enable_debug() + passes.convert.add_scf_to_cf(pm) + passes.convert.add_index_to_llvmir(pm) + + cpu.passes.ttcpuir.add_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + + passes.convert.add_scf_to_cf(pm) + passes.convert.add_cf_to_llvmir(pm) + passes.convert.add_arith_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0": + passes.llvmir.add_di_scope(pm) + pm.run(mod) + + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) + llvm.init_targets() + context = llvm.context() + llvm_mod = llvm.to_module(mod, context) + # TODO: + if not llvm_mod: + metadata["shared"] = 0 + return src + + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llvm.link_extern_libs(llvm_mod, paths) + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) + + # CPU doesn't have SMEM, but just to make it work for now. 
metadata["shared"] = 0 - return src + + # Cleanup + ret = str(llvm_mod) + del llvm_mod + del context + return ret @staticmethod def make_exe(src, metadata, options): diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index f795521fe204..a6cf99f742b2 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -48,7 +48,6 @@ def __init__(self, src, metadata): self.launch = lambda *args, **kwargs: None def __call__(self, *args, **kwargs): - print("CPULauncher.__call__") self.launch(*args, **kwargs) diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 1ccfb19f1526..302951d04d59 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,5 +1,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "triton/Conversion/TritonCPUToLLVM/Passes.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" #include @@ -10,15 +12,21 @@ namespace py = pybind11; -void init_triton_passes_ttcpuir(py::module &&m) { - // TODO: +void init_triton_cpu_passes_ttcpuir(py::module &&m) { + using namespace mlir::triton; + m.def("add_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); + }); } void init_triton_cpu(py::module &&m) { auto passes = m.def_submodule("passes"); - init_triton_passes_ttcpuir(passes.def_submodule("ttcpuir")); + init_triton_cpu_passes_ttcpuir(passes.def_submodule("ttcpuir")); m.def("load_dialects", [](mlir::MLIRContext &context) { - // TODO: + mlir::DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); }); } From 3df2a1df69ab2a35fdb093670ca3e38b9badc670 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 6 May 2024 14:27:35 -0700 Subject: [PATCH 005/165] Update README.md A quick addition on how to use it. --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 1b8e0abbb939..dba7f2d0345b 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,18 @@ subdirectory (similar to how GPU vendors are supported today). We're starting with a clone to give ourselves maximum development flexibility as this project gets off the ground! +# How to use it? + +Build it like a normal Triton, but just pass TRITON_CPU_BACKEND=1 to use the CPU backend over a GPU backend, if any. + +``` +TRITON_CPU_BACKEND=1 python3 tutorials/01-vector-add.py +``` + +**NOTE: It's still work in progress.** + +--- + # Upstream README
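For context on the command in the README hunk above: `tutorials/01-vector-add.py` is the stock upstream tutorial, and nothing in the user-facing kernel changes for the CPU backend; only the `TRITON_CPU_BACKEND=1` environment variable differs. A condensed version of that tutorial's kernel and launch wrapper is sketched below for reference (illustrative only; per the note above, end-to-end CPU execution is still work in progress at this point in the series).

```
# Condensed from the upstream tutorials/01-vector-add.py. The kernel source is
# backend-agnostic, so the same code is what runs under TRITON_CPU_BACKEND=1.
import torch

import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x + y, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    return output
```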
From d8a8211820eb0e32d415b60c9803701b025ed678 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 13 May 2024 06:34:46 -0700 Subject: [PATCH 006/165] Convert tt.func and tt.return (#4) Summary: This is stll a kind of the boilerplate and basic lowering for the first milestone (compiling vector addition). This PR firstly lowers `tt.func` and `tt.return`. Test Plan: It can safely compile an empty kernel. ``` @triton.jit def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): return ``` > TRITON_ENABLE_LLVM_DEBUG=1 TRITON_CPU_BACKEND=1 python3 empty_kerne.py ``` //===-------------------------------------------===// Legalizing operation : 'tt.func'(0x73be2a0) { * Fold { } -> FAILURE : unable to fold * Pattern : 'tt.func -> ()' { Trying to match "(anonymous namespace)::FuncOpConversion" ** Insert : 'llvm.func'(0x6c04c70) ** Insert Block into : 'llvm.func'(0x6c04c70) ** Insert Block into : 'llvm.func'(0x6c04c70) ** Erase : 'tt.func'(0x73be2a0) "(anonymous namespace)::FuncOpConversion" result 1 //===-------------------------------------------===// Legalizing operation : 'llvm.func'(0x6c04c70) { } -> SUCCESS : operation marked legal by the target //===-------------------------------------------===// ... //===-------------------------------------------===// Legalizing operation : 'tt.return'(0x73efeb0) { "tt.return"() : () -> () * Fold { } -> FAILURE : unable to fold * Pattern : 'tt.return -> ()' { Trying to match "(anonymous namespace)::ReturnOpConversion" ** Insert : 'llvm.return'(0x73c0f00) ** Replace : 'tt.return'(0x73efeb0) "(anonymous namespace)::ReturnOpConversion" result 1 //===-------------------------------------------===// Legalizing operation : 'llvm.return'(0x73c0f00) { "llvm.return"() : () -> () } -> SUCCESS : operation marked legal by the target //===-------------------------------------------===// } -> SUCCESS : pattern applied successfully ``` --- .../PatternTritonCPUOpToLLVM.h | 36 +++++++++++++ .../TritonCPUToLLVM/TypeConverter.h | 2 + .../Conversion/TritonCPUToLLVM/Utility.h | 26 +++++++++ lib/Conversion/TritonCPUToLLVM/CMakeLists.txt | 2 + .../TritonCPUToLLVM/ControlFlowOpToLLVM.cpp | 37 +++++++++++++ .../TritonCPUToLLVM/FuncOpToLLVM.cpp | 54 +++++++++++++++++++ .../TritonCPUToLLVM/TritonCPUToLLVM.cpp | 26 ++++++++- .../TritonCPUToLLVM/TypeConverter.cpp | 15 ++++++ third_party/cpu/backend/compiler.py | 4 +- 9 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h create mode 100644 include/triton/Conversion/TritonCPUToLLVM/Utility.h create mode 100644 lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp create mode 100644 lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp diff --git a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h new file mode 100644 index 000000000000..d2212eb34009 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h @@ -0,0 +1,36 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace triton { +// Some populate* functions have name collisions with the ones for GPUs. 
+namespace cpu { + +constexpr int patternBenefitDefault = 1; +constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; +constexpr int patternBenefitClampOptimizedPattern = 20; +constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; + +void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h b/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h index 57f4ce78f091..8ed9e6d4d849 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h +++ b/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h @@ -15,6 +15,8 @@ class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonPointerType(triton::PointerType type); }; #endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/Utility.h b/include/triton/Conversion/TritonCPUToLLVM/Utility.h new file mode 100644 index 000000000000..08d3b5e061a8 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/Utility.h @@ -0,0 +1,26 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "llvm/Support/ErrorHandling.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir { +namespace LLVM { + +// TODO: Not sure we need this for CPU backends. 
+inline bool isKernel(FunctionOpInterface funcOp) { + return funcOp.getVisibility() == SymbolTable::Visibility::Public; +} + +} // namespace LLVM +} // namespace mlir + +#endif diff --git a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt index 0dfa7cb5ca18..175115628597 100644 --- a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt @@ -1,4 +1,6 @@ add_triton_library(TritonCPUToLLVM + ControlFlowOpToLLVM.cpp + FuncOpToLLVM.cpp TypeConverter.cpp TritonCPUToLLVM.cpp diff --git a/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp new file mode 100644 index 000000000000..a270c0d60845 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp @@ -0,0 +1,37 @@ +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" +#include "llvm/Support/ErrorHandling.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + if (funcOp->hasAttr("cpu.kernel")) { + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + llvm_unreachable("Not implemented"); + } + return success(); + } +}; + +} // namespace + +void mlir::triton::cpu::populateControlFlowOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000000..9ecd470345ad --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,54 @@ +#include "mlir/Support/LogicalResult.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +namespace { + +using namespace mlir; +using namespace mlir::triton; + +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!LLVM::isKernel(funcOp)) { + llvm_unreachable("Not implemented"); + } + + LLVM::LLVMFuncOp newFuncOp = + *mlir::convertFuncOpToLLVMFuncOp(funcOp, rewriter, *getTypeConverter()); + if (!newFuncOp) { + return failure(); + } + + auto ctx = funcOp->getContext(); + if (LLVM::isKernel(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. 
+ newFuncOp->setAttr("cpu.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + } else { + llvm_unreachable("Not implemented"); + } + + rewriter.eraseOp(funcOp); + return success(); + } +}; + +} // namespace + +void mlir::triton::cpu::populateFuncOpConversionPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp index 3646c92d80f9..28d320df32d3 100644 --- a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp +++ b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp @@ -14,6 +14,7 @@ #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonCPUToLLVM/Passes.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" #include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" @@ -71,7 +72,30 @@ struct ConvertTritonCPUToLLVM TritonCPUToLLVMTypeConverter typeConverter(context, option); TritonLLVMConversionTarget convTarget(*context); - // TODO: + // Lower functions + { + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context); + RewritePatternSet funcPatterns(context); + mlir::triton::cpu::populateFuncOpConversionPattern( + typeConverter, funcPatterns, + mlir::triton::cpu::patternBenefitDefault); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + } + + RewritePatternSet patterns(context); + int benefit = + mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions; + mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter, + patterns, benefit); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); } }; diff --git a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp index 6a5ac668dd52..e8ca0810c195 100644 --- a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp @@ -1,6 +1,7 @@ #include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Conversion/MLIRTypes.h" +#include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::triton; @@ -9,8 +10,22 @@ TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( MLIRContext *ctx, LowerToLLVMOptions &option, const DataLayoutAnalysis *analysis) : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + // Internally store bfloat16 as int16 addConversion([&](BFloat16Type type) -> std::optional { return IntegerType::get(type.getContext(), 16); }); } + +Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (pointeeType.isa()) { + llvm_unreachable("Not implemented"); + } + return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); +} diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 47b44ca84cdc..84564cabef0c 100644 --- 
a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -137,8 +137,8 @@ def make_llir(src, metadata, options): @staticmethod def make_exe(src, metadata, options): - # Right now, src is just TTIR. Extract kernel name from tt.func. - names = re.findall(r"\s+tt.func public @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) + # Just a quick hack while developing the backend. + names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) assert len(names) == 1 metadata["name"] = names[0] From 2752fa83474da4f7c0e91449d4b8d486771028ef Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Tue, 14 May 2024 15:49:18 -0700 Subject: [PATCH 007/165] [BACKEND][CPU] Convert tt.get_program_id and tt.print (Hello World) (#1) Summary: As title, `tl.program_id` needs to be supported first. As of now, we think pid will be provided as additional function arguments to the kernel. So, getting program_id is mapped to reading one of the last three arguments. I also quickly implemented `tl.device_print` or `print`, only for scalar types for a quick "Hello, World!" testing. Test Plan: Tested with a simple example: ``` @triton.jit def add_kernel(...): pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. foo = pid + 42 tl.device_print("Hello, World!", foo, pid) ``` The resulting .llir is valid: ``` @printfFormat_1 = internal constant [31 x i8] c"pid (%u, %u, %u) test: %u, %u\0A\00" declare !dbg !3 i32 @printf(ptr, ...) define void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, i32 %4, i32 %5, i32 %6) !dbg !7 { %8 = add i32 %4, 42, !dbg !8 %9 = call i32 (ptr, ...) @printf(ptr @printfFormat_0, i32 %4, i32 %5, i32 %6, i32 %8, i32 %4) ret void, !dbg !9 } ``` Tried to compile with a fake main function: ``` > % cat main.c extern void add_kernel(float*, float*, float*, int, int, int, int); int main() { add_kernel(0, 0, 0, 4, 5, 6, 7); } > % llc -filetype=obj add_kernel.llir && clang -o a.out add_kernel.llir.o main.c > % ./a.out pid (5, 6, 7) Hello, World!: 47, 5 ``` --- .../TritonCPUToLLVM/CPUTargetInfo.h | 22 +++ .../PatternTritonCPUOpToLLVM.h | 8 ++ .../Conversion/TritonCPUToLLVM/Utility.h | 13 +- lib/Conversion/TritonCPUToLLVM/CMakeLists.txt | 3 + .../TritonCPUToLLVM/CPUTargetInfo.cpp | 49 +++++++ .../TritonCPUToLLVM/PrintOpToLLVM.cpp | 131 ++++++++++++++++++ .../TritonCPUToLLVM/SPMDOpToLLVM.cpp | 39 ++++++ .../TritonCPUToLLVM/TritonCPUToLLVM.cpp | 5 + 8 files changed, 261 insertions(+), 9 deletions(-) create mode 100644 include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h create mode 100644 lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp create mode 100644 lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp create mode 100644 lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp diff --git a/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h b/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h new file mode 100644 index 000000000000..66f6b57b1c57 --- /dev/null +++ b/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h @@ -0,0 +1,22 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" + +namespace mlir::triton::cpu { +class CPUTargetInfo { +public: + // Note: we may revisit for different CPU ISAs like AVX and Neon. 
+ CPUTargetInfo() {} + + Value programId(ConversionPatternRewriter &rewriter, Location loc, + LLVM::LLVMFuncOp funcOp, int axis) const; + + void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const; + + ~CPUTargetInfo() {} +}; +} // namespace mlir::triton::cpu +#endif // TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H diff --git a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h index d2212eb34009..f5cd3612dac5 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h +++ b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h @@ -1,7 +1,9 @@ #ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H #define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H +#include "CPUTargetInfo.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" using namespace mlir; @@ -17,6 +19,11 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; constexpr int patternBenefitClampOptimizedPattern = 20; constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; +void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const cpu::CPUTargetInfo &targetInfo, + PatternBenefit benefit); + void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit); @@ -27,6 +34,7 @@ void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const CPUTargetInfo &targetInfo, PatternBenefit benefit); } // namespace cpu diff --git a/include/triton/Conversion/TritonCPUToLLVM/Utility.h b/include/triton/Conversion/TritonCPUToLLVM/Utility.h index 08d3b5e061a8..8562271340a1 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonCPUToLLVM/Utility.h @@ -12,15 +12,10 @@ using namespace mlir; using namespace mlir::triton; -namespace mlir { -namespace LLVM { +// TODO: Do better refactoring. +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -// TODO: Not sure we need this for CPU backends. 
-inline bool isKernel(FunctionOpInterface funcOp) { - return funcOp.getVisibility() == SymbolTable::Visibility::Public; -} - -} // namespace LLVM -} // namespace mlir +#undef DEBUG_TYPE +#define DEBUG_TYPE "ttcpu_to_llvm" #endif diff --git a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt index 175115628597..db507557fb22 100644 --- a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt @@ -1,6 +1,9 @@ add_triton_library(TritonCPUToLLVM ControlFlowOpToLLVM.cpp + CPUTargetInfo.cpp FuncOpToLLVM.cpp + PrintOpToLLVM.cpp + SPMDOpToLLVM.cpp TypeConverter.cpp TritonCPUToLLVM.cpp diff --git a/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp b/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp new file mode 100644 index 000000000000..8dd050b80bbf --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp @@ -0,0 +1,49 @@ +#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" + +namespace { +LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("printf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + // int printf(char* format, ...) + SmallVector argsType{ptr_ty(context)}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(context), funcName, + funcType); +} +} // namespace + +namespace mlir::triton::cpu { + +Value CPUTargetInfo::programId(ConversionPatternRewriter &rewriter, + Location loc, LLVM::LLVMFuncOp funcOp, + int axis) const { + assert(axis >= 0 && axis < 3); + + // program_id for CPU is provided as function arguments. The last three + // arguments are __grid0 to __grid2 of i32. 
+ assert(funcOp && funcOp.getArguments().size() >= 3); + return funcOp.getArgument(funcOp.getArguments().size() - 3 + axis); +} + +void CPUTargetInfo::printf(ConversionPatternRewriter &rewriter, + Value formatStrStart, int /*formatStrByteCount*/, + ValueRange args) const { + auto loc = UnknownLoc::get(rewriter.getContext()); + SmallVector formatStrAndArgs{formatStrStart}; + for (auto arg : args) { + formatStrAndArgs.push_back(arg); + } + call(getPrintfDeclaration(rewriter), formatStrAndArgs); +} +} // namespace mlir::triton::cpu diff --git a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp new file mode 100644 index 000000000000..96a1c5d1619f --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp @@ -0,0 +1,131 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter, + const CPUTargetInfo &targetInfo, + PatternBenefit benefit) + : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return targetInfo.programId( + rewriter, loc, op->getParentOfType(), axis); + }; + SmallVector values = {getPid(0), getPid(1), getPid(2)}; + + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "pid (" << getFormatSubstr(values[0]) << ", " + << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) + << ")" << op.getPrefix(); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + if (op.getOperand(i).getType().dyn_cast()) { + llvm_unreachable("Not implemented for tensor types"); + } + + // Only support scalars for now. + assert(elems.size() == 1); + if (i != 0) { + os << ", "; + } + os << getFormatSubstr(elems[0]); + values.push_back(elems[0]); + } + + llPrintf(formatStr, values, rewriter); + rewriter.eraseOp(op); + return success(); + } + + // TODO: This code is the same as the GPU-backend code. Consider refactoring. + std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt) const { + Type type = value.getType(); + if (type.isa()) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. 
+ std::string ret = + "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } else if (hex) { + prefix += "0"; + prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isSignedInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "lli"; + else + return prefix + "i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "llu"; + else + return prefix + "u"; + } + assert(false && "not supported type"); + return ""; + } + + Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) const { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), + rewriter, "printfFormat_", msgNewline); + targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; + } + +protected: + const CPUTargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::cpu::populatePrintOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const CPUTargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp new file mode 100644 index 000000000000..65fef7a7d0d5 --- /dev/null +++ b/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp @@ -0,0 +1,39 @@ +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" +#include "triton/Conversion/TritonCPUToLLVM/Utility.h" + +namespace { + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +struct GetProgramIdOpConversion + : public ConvertOpToLLVMPattern { + explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, + const CPUTargetInfo &targetInfo, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value programId = targetInfo.programId( + rewriter, op->getLoc(), op->getParentOfType(), + op.getAxisAsInt()); + rewriter.replaceOp(op, programId); + return success(); + } + +private: + const CPUTargetInfo &targetInfo; +}; + +} // namespace + +void mlir::triton::cpu::populateSPMDOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const CPUTargetInfo &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, targetInfo, benefit); +} diff --git a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp index 28d320df32d3..cb15f87ee206 100644 --- a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp +++ b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp @@ -89,10 +89,15 @@ struct ConvertTritonCPUToLLVM } RewritePatternSet patterns(context); 
+ mlir::triton::cpu::CPUTargetInfo targetInfo; int benefit = mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions; mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter, patterns, benefit); + mlir::triton::cpu::populatePrintOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); + mlir::triton::cpu::populateSPMDOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From 1f6133553fd4d1c1dccb000d38f768f770123e94 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Wed, 15 May 2024 22:19:36 -0700 Subject: [PATCH 008/165] Quick patches to make it work after rebasing (#3) --- lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp | 4 ++-- lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp | 2 +- lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp | 6 +++--- third_party/cpu/backend/compiler.py | 6 +++--- third_party/cpu/backend/driver.py | 4 +++- 5 files changed, 12 insertions(+), 10 deletions(-) diff --git a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp index 96a1c5d1619f..b424cf8e37b7 100644 --- a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp +++ b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp @@ -39,7 +39,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { for (size_t i = 0; i < op.getNumOperands(); i++) { auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); - if (op.getOperand(i).getType().dyn_cast()) { + if (dyn_cast(op.getOperand(i).getType())) { llvm_unreachable("Not implemented for tensor types"); } @@ -61,7 +61,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { std::string getFormatSubstr(Value value, bool hex = false, std::optional width = std::nullopt) const { Type type = value.getType(); - if (type.isa()) { + if (isa(type)) { return "%p"; } // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the diff --git a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp index e8ca0810c195..72ef796fdabb 100644 --- a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp @@ -24,7 +24,7 @@ Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( triton::PointerType type) { auto ctx = type.getContext(); auto pointeeType = type.getPointeeType(); - if (pointeeType.isa()) { + if (isa(pointeeType)) { llvm_unreachable("Not implemented"); } return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); diff --git a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp index dabc2a27a87b..97948404bdbf 100644 --- a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp +++ b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp @@ -26,7 +26,7 @@ TritonCPUTypeConverter::TritonCPUTypeConverter(MLIRContext *context) addConversion([this](triton::PointerType ptrType) -> triton::PointerType { // Check whether tensor pointer `tt.ptr>` auto pointeeTensorType = - ptrType.getPointeeType().dyn_cast(); + dyn_cast(ptrType.getPointeeType()); if (pointeeTensorType == nullptr) return ptrType; @@ -99,9 +99,9 @@ TritonCPUConversionTarget::TritonCPUConversionTarget( // We have requirements for the data layouts addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { Attribute aEncoding = - dotOp.getA().getType().cast().getEncoding(); + cast(dotOp.getA().getType()).getEncoding(); Attribute bEncoding = - 
dotOp.getB().getType().cast().getEncoding(); + cast(dotOp.getB().getType()).getEncoding(); // TODO: return false; }); diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 84564cabef0c..3c293cdf468f 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -7,7 +7,7 @@ from typing import Any from triton._C.libtriton import cpu, ir, llvm, passes -from triton.backends.compiler import BaseBackend +from triton.backends.compiler import BaseBackend, GPUTarget @dataclass(frozen=True) @@ -35,8 +35,8 @@ def hash(self): class CPUBackend(BaseBackend): @staticmethod - def supports_target(target: tuple): - return target[0] == "cpu" + def supports_target(target: GPUTarget): + return target.backend == "cpu" def __init__(self, target: tuple) -> None: super().__init__(target) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index a6cf99f742b2..3f3816a99b9f 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,3 +1,4 @@ +from triton.backends.compiler import GPUTarget from triton.backends.driver import CPUDriverBase # ------------------------ @@ -60,7 +61,8 @@ def __init__(self): def get_current_target(self): # Capability and warp size are zeros for CPU. - return ("cpu", 0, 0) + # TODO: GPUTarget naming isn't obviously good. + return GPUTarget("cpu", 0, 0) @staticmethod def is_active(): From 7b56e5a59eeb019304a4664e8b2202177e4dfd15 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 2 May 2024 10:52:16 -0700 Subject: [PATCH 009/165] Support basic lowering through vector dialect in CPU backend. Signed-off-by: Ilya Enkovich --- bin/RegisterTritonDialects.h | 29 +- .../Conversion/TritonCPUToLLVM/CMakeLists.txt | 2 +- .../Dialect/TritonCPU/IR/TritonCPUAttrDefs.td | 5 +- .../Dialect/TritonCPU/IR/TritonCPUDialect.td | 3 + .../Dialect/TritonCPU/IR/TritonCPUOps.td | 51 +++ lib/Conversion/CMakeLists.txt | 4 +- lib/Dialect/TritonCPU/IR/Dialect.cpp | 37 +- python/src/llvm.cc | 67 ++++ python/src/passes.cc | 5 +- python/test/unit/language/test_core.py | 32 ++ third_party/cpu/CMakeLists.txt | 5 + third_party/cpu/backend/__init__.py | 0 third_party/cpu/backend/compiler.py | 73 ++-- third_party/cpu/backend/driver.cpp | 224 ++++++++++++ third_party/cpu/backend/driver.py | 326 ++++++++++++++++-- third_party/cpu/include/CMakeLists.txt | 2 + .../include/TritonCPUToLLVM/CMakeLists.txt | 3 + .../cpu/include/TritonCPUToLLVM/Passes.h | 36 ++ .../cpu/include/TritonCPUToLLVM/Passes.td | 46 +++ .../include/TritonToTritonCPU/CMakeLists.txt | 3 + .../cpu/include/TritonToTritonCPU/Passes.h | 37 ++ .../cpu/include/TritonToTritonCPU/Passes.td | 62 ++++ third_party/cpu/lib/CMakeLists.txt | 2 + .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 13 + .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 276 +++++++++++++++ .../TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp | 98 ++++++ .../lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 277 +++++++++++++++ .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp | 25 ++ .../cpu/lib/TritonCPUToLLVM/TypeConverter.cpp | 43 +++ .../cpu/lib/TritonCPUToLLVM/TypeConverter.h | 22 ++ .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 15 + .../lib/TritonToTritonCPU/ConvertDotOp.cpp | 102 ++++++ .../ConvertElementwiseOps.cpp | 300 ++++++++++++++++ .../TritonToTritonCPU/ConvertMemoryOps.cpp | 277 +++++++++++++++ .../lib/TritonToTritonCPU/ConvertPtrOps.cpp | 191 ++++++++++ .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 26 ++ .../lib/TritonToTritonCPU/TypeConverter.cpp | 51 +++ 
.../cpu/lib/TritonToTritonCPU/TypeConverter.h | 19 + third_party/cpu/triton_cpu.cc | 45 ++- 39 files changed, 2754 insertions(+), 80 deletions(-) create mode 100644 third_party/cpu/backend/__init__.py create mode 100644 third_party/cpu/backend/driver.cpp create mode 100644 third_party/cpu/include/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.h create mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.td create mode 100644 third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.h create mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.td create mode 100644 third_party/cpu/lib/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h create mode 100644 third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 4681acf8abf8..ca922e824793 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -17,6 +17,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "cpu/include/TritonToTritonCPU/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -69,15 +71,22 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + // CPU passes + mlir::triton::cpu::registerTritonToTritonCPUPasses(); + mlir::triton::cpu::registerTritonToTritonCPUPipeline(); + mlir::triton::cpu::registerTritonCPUToLLVMPasses(); + mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); + // TODO: register Triton & TritonGPU passes - registry - .insert(); + registry.insert(); } diff --git a/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt b/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt index 64b36523d35d..0936dff12d91 100644 --- a/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ b/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) -add_public_tablegen_target(TritonCPUConversionPassIncGen) 
+add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td index df933dd49511..57f6c7c9bd71 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td @@ -17,9 +17,8 @@ class TritonCPU_Attr traits = [], string baseCppClass = "::mlir::Attribute"> : AttrDef { - let description = [{ - WIP... - }]; + let description = [{TritonCPU attr.}]; + let attrName = "triton.cpu." # attrMnemonic; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td index 9ccac13f0b58..260db2743046 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td @@ -17,6 +17,7 @@ def TritonCPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "tensor::TensorDialect", + "mlir::memref::MemRefDialect", ]; let extraClassDeclaration = [{ @@ -24,6 +25,8 @@ def TritonCPU_Dialect : Dialect { }]; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 16d9e433e899..bb7417ebd03e 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -7,6 +7,57 @@ include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +class TTC_Op traits = []> : + Op { +} + +def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> { + let summary = "Extract base memref from a block pointer"; + + let description = [{ + Extract base memref from a block pointer. It covers whole base tensor memory, + not only the block referenced. Base pointer, shape, and strides are used + in the resulting memref. Offsets and block shape are ignored. + + }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs AnyRankedOrUnrankedMemRef:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> { + let summary = "Extract indices from a block pointer."; + + let description = [{ + Extract indices that can be used to access the block using its base memref. + Indices are supposed to be used for vector loads/stores with the base + memref extracted from the same block pointer. 
+ }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs Variadic:$result); + + let builders = [ + OpBuilder<(ins "Value":$src)> + ]; + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} #endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 5c3aa2c1a827..83db4ae41607 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(TritonToTritonCPU) +#add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -add_subdirectory(TritonCPUToLLVM) +#add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index e28a65358dca..e5eb53caf686 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -2,16 +2,19 @@ #include +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/TypeSwitch.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" + using namespace mlir; +using namespace mlir::triton; using namespace mlir::triton::cpu; //===----------------------------------------------------------------------===// @@ -20,6 +23,35 @@ using namespace mlir::triton::cpu; #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" +void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +/// Parse an attribute registered to this dialect. +::mlir::Attribute +TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser, + ::mlir::Type type) const { + llvm_unreachable("parse stub called"); +} + +/// Print an attribute registered to this dialect. 
+void TritonCPUDialect::printAttribute(::mlir::Attribute attr, + ::mlir::DialectAsmPrinter &os) const { + llvm_unreachable("print stub called"); +} + +void ExtractIndicesOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, Value src) { + assert(triton::isTensorPointerType(src.getType()) && + "Unexecpeted source type"); + auto tensorTy = dyn_cast( + dyn_cast(src.getType()).getPointeeType()); + SmallVector resTypes(tensorTy.getRank(), builder.getIndexType()); + build(builder, state, resTypes, src); +} + void TritonCPUDialect::initialize() { registerTypes(); @@ -34,6 +66,9 @@ void TritonCPUDialect::initialize() { >(); } +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" + // verify TritonCPU ops LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/python/src/llvm.cc b/python/src/llvm.cc index c86bf671a7df..ca7b9c911643 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -3,6 +3,8 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -21,6 +23,7 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Host.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Instrumentation/AddressSanitizer.h" @@ -403,6 +406,70 @@ void init_triton_llvm(py::module &&m) { py::arg("flags") = std::vector{}, py::arg("enable_fp_fusion") = false); + m.def("set_host_target", [](llvm::Module *mod) { + mod->setTargetTriple(llvm::sys::getDefaultTargetTriple()); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); + std::unique_ptr machine{target->createTargetMachine( + mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, + llvm::Reloc::PIC_)}; + mod->setDataLayout(machine->createDataLayout()); + }); + + m.def( + "translate_to_host_asm", + [](std::string llvmIR) -> py::object { + std::string res; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + res = translateLLVMIRToASM( + *module, llvm::sys::getDefaultTargetTriple(), + llvm::sys::getHostCPUName().str(), "", {}, false, false); + } + return py::str(res); + }, + ret::take_ownership); + + m.def( + "translate_to_bc", + [](const std::string llvmIR) -> py::object { + py::gil_scoped_release allow_threads; + // create LLVM module + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + // Write bitcode to a buffer. 
+ llvm::SmallVector buf; + llvm::BitcodeWriter writer(buf); + writer.writeModule(*module); + writer.writeStrtab(); + std::string bitcode(buf.begin(), buf.end()); + return py::bytes(bitcode); + }, + ret::take_ownership); + m.def( "translate_to_asm", [](std::string llvmIR, std::string triple, std::string proc, diff --git a/python/src/passes.cc b/python/src/passes.cc index c365aaf43589..9e34f6ad7fed 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -45,8 +45,8 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); - ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", - createConvertTritonToTritonCPUPass); + // ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", + // createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { @@ -89,6 +89,7 @@ void init_triton_passes_convert(py::module &&m) { ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); + ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass); } void init_triton_passes_llvmir(py::module &&m) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6c8f130e94d8..b42d715fea7b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -304,6 +304,7 @@ def filter_layouts(layouts): return [l for l in layouts if is_layout_applicable(l)] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) def test_empty_kernel(dtype_x, device): @@ -543,6 +544,7 @@ def test_dtype_codegen(): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -603,6 +605,7 @@ def promote_to_fp32(dtype_x, dtype_y): test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) def test_addptr(dtype, order, device): @@ -629,6 +632,7 @@ def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): np.testing.assert_allclose(y, to_numpy(y_tri)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y", [ # (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes @@ -649,6 +653,7 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) +@pytest.mark.cpu def test_unsigned_name_mangling(device): # Test that uint32 and int32 are mangled differently by the compiler SIZE = 128 @@ -685,6 +690,7 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # test bitwise ops # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -709,6 +715,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes @@ -731,6 +738,7 @@ def 
test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "dtype_x, dtype_y, op, mode_x, mode_y", @@ -755,6 +763,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- # test broadcast # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def test_broadcast(dtype, device): @@ -789,6 +798,7 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con # ---------- +@pytest.mark.cpu @pytest.mark.interpreter def test_slice(device): @@ -820,6 +830,7 @@ def slice_kernel(XBLOCK: tl.constexpr): # ------------------ +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_slice(device): dst = torch.empty(128, device=device) @@ -835,6 +846,7 @@ def _kernel(dst): # ---------------- # test expand_dims # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims(device): @@ -883,6 +895,7 @@ def expand_dims_kernel(dummy, N: tl.constexpr): expand_dims_kernel[(1, )](dummy_tensor, N) +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims_error_cases(device): @@ -946,6 +959,7 @@ def duplicate_dim2(dummy, N: tl.constexpr): # ---------------------------- # test invalid program id axis # ---------------------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_pid_axis(device): dst = torch.empty(128, device=device) @@ -962,6 +976,7 @@ def _kernel(dst): # --------------- # test where # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1014,6 +1029,7 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl. 
assert (z == to_numpy(z_tri)).all() +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): @@ -1058,6 +1074,7 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr", [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') @@ -1072,6 +1089,7 @@ def test_unary_op(dtype_x, expr, num_ctas, device): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) @@ -1082,6 +1100,7 @@ def test_math_op(dtype_x, expr, x, device): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_erf_op(dtype, device): @@ -1103,6 +1122,7 @@ def kernel(Z, X, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_fma_op(dtype, device): @@ -1128,6 +1148,7 @@ def kernel(Z, X, Y, W, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1140,6 +1161,7 @@ def test_math_divide_op(expr, num_ctas, device): # ------------- # test precise math # ------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr_prec, expr_ref", [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), @@ -1180,6 +1202,7 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): @@ -1225,6 +1248,7 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_shapes_as_params(device): @@ -1297,6 +1321,7 @@ def make_ptr_str(name, shape): return f"{name} + {' + '.join(offsets)}" +@pytest.mark.cpu # TODO: handle `%4 = ttg.convert_layout %3 : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>`` @pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] @@ -1366,6 +1391,7 @@ def tuples_fn(a, b): a * b +@pytest.mark.cpu @pytest.mark.interpreter def test_tuples(device): @@ -1458,6 +1484,7 @@ def noinline_multi_values_fn(x, y, Z): tl.store(Z, z) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) def test_noinline(mode, device): @@ -1780,6 +1807,7 @@ def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ @@ -4729,6 +4757,7 @@ def kernel(VALUE, X): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @@ -4766,6 +4795,7 @@ def kernel(Z, X, Y): np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) 
+@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_shape(device): @@ -4779,6 +4809,7 @@ def kernel(X): np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) +@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_scalar_shape(device): @@ -4796,6 +4827,7 @@ def kernel(X, s): reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("formats", reshape_list) def test_reshape(formats, device): diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 683889547b0a..d8be71ad6c11 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,3 +1,8 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) + target_link_libraries(TritonCPU PUBLIC MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/__init__.py b/third_party/cpu/backend/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 3c293cdf468f..357b5f448fe9 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -4,7 +4,7 @@ import re from dataclasses import dataclass -from typing import Any +from typing import Any, Tuple from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget @@ -20,6 +20,8 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee",) + allow_fp8e4nv: bool = False # TODO: We may introduce CPU-specific options like # of cores. 
@@ -40,7 +42,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "exe" + self.binary_ext = "bc" def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -62,7 +64,6 @@ def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) - passes.ttir.add_rewrite_tensor_pointer(pm) passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) @@ -77,33 +78,34 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.ttir.add_convert_to_ttcpuir(pm) - - # - # TODO: - # - + cpu.passes.ttcpuir.add_triton_to_triton_cpu_pipeline(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) pm.run(mod) + metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) return mod @staticmethod def make_llir(src, metadata, options): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + metadata["threads_per_warp"] = 1 mod = src # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) - - cpu.passes.ttcpuir.add_to_llvmir(pm) - passes.common.add_canonicalizer(pm) - passes.common.add_cse(pm) - - passes.convert.add_scf_to_cf(pm) - passes.convert.add_cf_to_llvmir(pm) + cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) + passes.convert.add_math_to_llvmir(pm) + cpu.passes.ttcpuir.add_math_to_libm(pm) + cpu.passes.ttcpuir.add_vector_to_llvmir(pm) + cpu.passes.ttcpuir.add_memref_to_llvmir(pm) passes.convert.add_arith_to_llvmir(pm) + cpu.passes.ttcpuir.add_func_to_llvmir(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) @@ -111,45 +113,40 @@ def make_llir(src, metadata, options): passes.llvmir.add_di_scope(pm) pm.run(mod) + # Find kernel fn + kernel_names = cpu.find_kernel_names(mod) + assert len(kernel_names) == 1, f"expected exactly 1 kernel in a module, got {kernel_names}" + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) - - # TODO: - if not llvm_mod: - metadata["shared"] = 0 - return src - - if options.extern_libs: - paths = [path for (name, path) in options.extern_libs] - llvm.link_extern_libs(llvm_mod, paths) + llvm.set_host_target(llvm_mod) + #if options.extern_libs: + # paths = [path for (name, path) in options.extern_libs] + # llvm.link_extern_libs(llvm_mod, paths) llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - - # CPU doesn't have SMEM, but just to make it work for now. + # Get some metadata metadata["shared"] = 0 - - # Cleanup + metadata["name"] = kernel_names[0] ret = str(llvm_mod) del llvm_mod del context return ret @staticmethod - def make_exe(src, metadata, options): - # Just a quick hack while developing the backend. - names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) - assert len(names) == 1 - metadata["name"] = names[0] - - # TODO: Call llc to create an executable. 
- return src + def make_bc(src, metadata, options): + if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": + print("********** Module ASM **********") + print(llvm.translate_to_host_asm(src)) + ret = llvm.translate_to_bc(src) + return ret def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options) + stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.cpp b/third_party/cpu/backend/driver.cpp new file mode 100644 index 000000000000..babff3dfdebe --- /dev/null +++ b/third_party/cpu/backend/driver.cpp @@ -0,0 +1,224 @@ +//===- driver.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/TargetSelect.h" + +#include +#include +#include +#include +#include +#include +#include + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + return Py_BuildValue("{s:i}", "max_shared_mem", 0); +} + +bool getBoolEnv(const std::string &env) { + const char *s = std::getenv(env.c_str()); + std::string str(s ? 
s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return (str == "on" || str == "true" || str == "1"); +} + +llvm::orc::ThreadSafeContext &getThreadSafeContext() { + static llvm::orc::ThreadSafeContext tsc; + static std::once_flag init_flag; + std::call_once(init_flag, []() { + auto context = std::make_unique(); + tsc = llvm::orc::ThreadSafeContext(std::move(context)); + }); + return tsc; +} + +std::string llvmErrToString(const llvm::Error &err) { + std::string res; + llvm::raw_string_ostream os(res); + os << err; + return res; +}; + +struct CompiledKernel { + std::unique_ptr execution_session; + std::unique_ptr data_layout; + std::unique_ptr mangle; + std::unique_ptr object_layer; + std::unique_ptr compiler_layer; + llvm::orc::JITDylib *dylib = nullptr; + + CompiledKernel() = default; + CompiledKernel(CompiledKernel &&) = default; + + ~CompiledKernel() { + if (execution_session) + llvm::cantFail(execution_session->endSession()); + } +}; + +std::vector> compiled_kernels; + +static PyObject *loadBitcode(PyObject *self, PyObject *args) { + const char *name; + int shared; + PyObject *py_bytes; + int devId; + + if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) { + std::cerr << "loadBitcode arg parse failed" << std::endl; + return NULL; + } + + std::string kernel_name = name; + size_t binary_size = PyBytes_Size(py_bytes); + const char *binary_ptr = PyBytes_AsString(py_bytes); + + llvm::LLVMContext context; + auto buf = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(binary_ptr, binary_size)); + auto mod = llvm::parseBitcodeFile(*buf, context); + if (!mod) { + std::cerr << "Failed to parse LLVM bitcode module" << std::endl; + return NULL; + } + + if (getBoolEnv("MLIR_ENABLE_DUMP")) { + llvm::errs() << "********** Loaded Module (kernel_name=" << name + << ") **********\n" + << **mod << "\n"; + } + + auto init_err = llvm::InitializeNativeTarget(); + if (init_err) { + std::cerr << "Failed to initialize native target." << std::endl; + return NULL; + } + + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + auto self_epc = + llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); + + auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!detect_host_res) { + std::cerr << "Failed to initialize JITTargetMachineBuilder: " + << llvmErrToString(detect_host_res.takeError()); + return NULL; + } + llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); + + auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); + if (!data_layout_res) { + std::cerr << "Failed to initialize data layout: " + << llvmErrToString(data_layout_res.takeError()); + return NULL; + } + + CompiledKernel kernel; + kernel.execution_session = + std::make_unique(std::move(self_epc)); + kernel.data_layout = + std::make_unique(std::move(*data_layout_res)); + kernel.mangle = std::make_unique( + *kernel.execution_session, *kernel.data_layout); + kernel.object_layer = std::make_unique( + *kernel.execution_session, + []() { return std::make_unique(); }); + kernel.compiler_layer = std::make_unique( + *kernel.execution_session, *kernel.object_layer, + std::make_unique(std::move(tmb))); + + auto dylib_res = kernel.execution_session->createJITDylib("
"); + if (!dylib_res) { + std::cerr << "Failed to create initialize JITDylib: " + << llvmErrToString(dylib_res.takeError()); + return NULL; + } + + kernel.dylib = &(*dylib_res); + kernel.dylib->addGenerator(llvm::cantFail( + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + kernel.data_layout->getGlobalPrefix()))); + + // Compile module. + (**mod).setDataLayout(*kernel.data_layout); + llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); + auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); + if (err) { + std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); + return NULL; + } + + // Find kernel function pointer. + auto lookup_res = + kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); + if (!lookup_res) { + std::cerr << "Failed to find function " << std::string(name) + << "\nError: " << llvmErrToString(lookup_res.takeError()); + return NULL; + } + uint64_t fn_ptr = lookup_res->getAddress().getValue(); + + compiled_kernels.push_back( + std::make_unique(std::move(kernel))); + auto *kernel_ptr = compiled_kernels.back().get(); + + return Py_BuildValue("(KKii)", reinterpret_cast(kernel_ptr), + reinterpret_cast(fn_ptr), 0, 0); +} + +static PyObject *initContext(PyObject *self, PyObject *args) { + return Py_BuildValue("(K)", (uint64_t)0); +} + +static PyObject *initDevices(PyObject *self, PyObject *args) { + return Py_BuildValue("(i)", 1); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBitcode, METH_VARARGS, + "Load provided SPV into ZE driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cpu_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cpu_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 3f3816a99b9f..743684d2640f 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,5 +1,100 @@ +import os +import hashlib +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget -from triton.backends.driver import CPUDriverBase + +dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") +llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") +llvm_root = os.path.expanduser(llvm_root) +llvm_dirs = os.listdir(llvm_root) +if len(llvm_dirs) == 1: + llvm_root = os.path.join(llvm_root, llvm_dirs[0]) +include_dir = [ + os.path.join(dirname, "include"), + os.path.join(llvm_root, "include"), +] +library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] +libraries = [ + "LLVMOrcJIT", + "LLVMPasses", + "LLVMX86CodeGen", + "LLVMX86AsmParser", + "LLVMX86Desc", + "LLVMX86Info", + "LLVMGlobalISel", + "LLVMSelectionDAG", + "LLVMHipStdPar", + "LLVMCoroutines", + "LLVMipo", + "LLVMFrontendOpenMP", + "LLVMInstrumentation", + "LLVMAsmPrinter", + "LLVMCodeGen", + "LLVMObjCARCOpts", + "LLVMLinker", + "LLVMVectorize", + "LLVMScalarOpts", + "LLVMInstCombine", + "LLVMFrontendOffloading", + "LLVMExecutionEngine", + "LLVMAggressiveInstCombine", + "LLVMTransformUtils", + "LLVMTarget", + 
"LLVMRuntimeDyld", + "LLVMJITLink", + "LLVMIRPrinter", + "LLVMBitWriter", + "LLVMAnalysis", + "LLVMProfileData", + "LLVMSymbolize", + "LLVMDebugInfoDWARF", + "LLVMObject", + "LLVMTextAPI", + "LLVMMCParser", + "LLVMMCDisassembler", + "LLVMMC", + "LLVMIRReader", + "LLVMCFGuard", + "LLVMBitReader", + "LLVMAsmParser", + "LLVMCore", + "LLVMBinaryFormat", + "LLVMOrcTargetProcess", + "LLVMTargetParser", + "LLVMRemarks", + "LLVMOrcShared", + "LLVMOption", + "LLVMDebugInfoCodeView", + "LLVMCodeGenTypes", + "LLVMBitstreamReader", + "LLVMSupport", + "LLVMDemangle", + "stdc++", +] + + +def compile_module_from_src(src, name): + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.cpp") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + # ------------------------ # Utils @@ -15,22 +110,12 @@ def __new__(cls): def __init__(self): pass + dirname = os.path.dirname(os.path.realpath(__file__)) + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") + self.load_binary = mod.load_binary - @staticmethod - def get_device_properties(device): - # This is just dummy for now. We will need to implement driver.c. - return { - "max_shared_mem": 0, - "multiprocessor_count": 0, - "sm_clock_rate": 0, - "mem_clock_rate": 0, - "mem_bus_width": 0, - } - - @staticmethod - def load_binary(name, kernel_asm, shared, device): - # This is just dummy for now. We will need to implement driver.c. - return (None, kernel_asm, 0, 0) + def get_device_properties(self, *args): + return {"max_shared_mem": 0} # ------------------------ @@ -38,27 +123,228 @@ def load_binary(name, kernel_asm, shared, device): # ------------------------ +def ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def make_launcher(constants, signature, ids): - pass + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors. 
+ arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + arg_types = (', '.join(f"{ty_to_cpp(ty)}" for i, ty in signature.items()) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiOKOOOO" + args_format + args_list = ', '.join(f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + + # generate glue code + src = f""" +#include +#include +#include +#include + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include +#include + +using kernel_ptr_t = void(*)({arg_types}); + +typedef struct _DevicePtrInfo {{ + void* dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + Py_DECREF(ret); // Thanks ChatGPT! 
+ return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // TODO: add OMP pragmas to run in parallel + for (uint32_t z = 0; z < gridZ; ++z) {{ + for (uint32_t y = 0; y < gridY; ++y) {{ + for (uint32_t x = 0; x < gridX; ++x) {{ + (*kernel_ptr)({args_list + ', ' if len(arg_decls) > 0 else ''} x, y, z); + }} + }} + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + + + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *py_obj_stream; + void* pKrnl; + + {' '.join([f"{_extracted_type(ty)} arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {', ' + arg_ptrs_list if len(signature) > 0 else ''})) {{ + return NULL; + }} + + void *pStream = PyLong_AsVoidPtr(py_obj_stream); + kernel_ptr_t kernel_ptr = reinterpret_cast(pKrnl); + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + run_omp_kernels(gridX, gridY, gridZ, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_cpu_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_cpu_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src class CPULauncher(object): def __init__(self, src, metadata): - # TODO: - self.launch = lambda *args, **kwargs: None + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids) + mod = compile_module_from_src(src, "__triton_cpu_launcher") + self.launch = mod.launch def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) -class CPUDriver(CPUDriverBase): +class CPUDriver(DriverBase): def 
__init__(self): self.utils = CPUUtils() self.launcher_cls = CPULauncher super().__init__() + def get_current_device(self): + return 0 + + def get_current_stream(self, device): + return 0 + def get_current_target(self): # Capability and warp size are zeros for CPU. # TODO: GPUTarget naming isn't obviously good. diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt new file mode 100644 index 000000000000..fc9a19e52b0d --- /dev/null +++ b/third_party/cpu/include/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..64b36523d35d --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) +add_public_tablegen_target(TritonCPUConversionPassIncGen) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h new file mode 100644 index 000000000000..74f74b00870c --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -0,0 +1,36 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H +#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +std::unique_ptr> createFuncOpToLLVMPass(); +std::unique_ptr> createMemoryOpToLLVMPass(); +std::unique_ptr> createGetProgramIdOpToLLVMPass(); + +void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); +void registerTritonCPUToLLVMPipeline(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td new file mode 100644 index 000000000000..c75b58b572f1 --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -0,0 +1,46 @@ +#ifndef TRITONCPU_CONVERSION_PASSES +#define TRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert FuncOp to LLVM for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createFuncOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton memory operations to LLVM for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createMemoryOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton GetProgramId to LLVM 
for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createGetProgramIdOpToLLVMPass()"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..56e231273ed6 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) +add_public_tablegen_target(TritonToTritonCPUPassIncGen) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h new file mode 100644 index 000000000000..ab98a8741a16 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -0,0 +1,37 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES_H +#define TRITONTOTRITONCPU_CONVERSION_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +std::unique_ptr> createConvertElementwiseOps(); +std::unique_ptr> createConvertMemoryOps(); +std::unique_ptr> createConvertPtrOps(); +std::unique_ptr> createConvertDotOp(); + +void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); +void registerTritonToTritonCPUPipeline(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td new file mode 100644 index 000000000000..77e6528c6943 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -0,0 +1,62 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES +#define TRITONTOTRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton memory ops."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::ModuleOp"> { + let summary = "Convert elementwise ops."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertElementwiseOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton ops related to pointer arithmetics."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertPtrOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotOp : 
Pass<"triton-cpu-convert-dot-op", "mlir::ModuleOp"> { + let summary = "Convert Triton DotOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertDotOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +#endif diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt new file mode 100644 index 000000000000..fc9a19e52b0d --- /dev/null +++ b/third_party/cpu/lib/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..884c9352ef1b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(TritonCPUToLLVM + FuncOpToLLVM.cpp + GetProgramIdOpToLLVM.cpp + MemoryOpToLLVM.cpp + Pipeline.cpp + TypeConverter.cpp + + DEPENDS + TritonCPUToLLVMConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRVectorToLLVMPass +) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000000..5895341fc34b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,276 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_FUNCOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. 
+ static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendProgramIdArgs(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + // 1. Modify the function type to add new arguments. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add new arguments. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + SmallVector amendedArgAttrs; + if (funcOp.getAllArgAttrs()) + amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new arguments to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(funcTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto modifiedFuncOp = funcOp; + if (LLVM::isKernel(funcOp)) + modifiedFuncOp = amendProgramIdArgs(modifiedFuncOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + modifiedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) + return failure(); + + // required by AxisInfoAnalysis + if (LLVM::isKernel(funcOp)) + rewriter.eraseOp(modifiedFuncOp); + rewriter.eraseOp(funcOp); + return success(); + } +}; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. 
+ auto funcOp = op->getParentOfType(); + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = + insert_val(packedResultsTy, packedResults, it.value(), it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. 
+ results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } +}; + +struct FuncOpToLLVM : public triton::impl::FuncOpToLLVMBase { + using FuncOpToLLVMBase::FuncOpToLLVMBase; + + FuncOpToLLVM() : FuncOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + // Lower tt.func + RewritePatternSet funcPatterns(context); + funcPatterns.add(typeConverter, + /*benefit=*/1); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, convTarget, std::move(funcPatterns)))) + return signalPassFailure(); + + // Lower tt.call, tt.return + int benefit = 10; + RewritePatternSet patterns(context); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createFuncOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp new file mode 100644 index 000000000000..4c593f1ff7aa --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp @@ -0,0 +1,98 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_GETPROGRAMIDOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. 
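// Sketch of the convention the pattern below relies on (established by the
// FuncOpToLLVM pass above): kernel functions gain three trailing i32
// arguments holding the x, y and z program ids, so tt.get_program_id(axis)
// lowers to a plain use of argument (numArgs - 3 + axis). Illustrative only:
//   llvm.func @kernel(..., %pid_x: i32, %pid_y: i32, %pid_z: i32)
//   tt.get_program_id x  ->  replaced by %pid_x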
+struct GetProgramIdOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); + auto args = funcOp.getArguments(); + // Last three args are x, y, z program ids. + auto argIdx = args.size() - 3 + op.getAxisAsInt(); + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + rewriter.replaceOp(op, args[argIdx]); + return success(); + } +}; + +struct GetProgramIdOpToLLVM + : public triton::impl::GetProgramIdOpToLLVMBase { + using GetProgramIdOpToLLVMBase::GetProgramIdOpToLLVMBase; + + GetProgramIdOpToLLVM() : GetProgramIdOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createGetProgramIdOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000000..594495c4ab9d --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,277 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_MEMORYOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. 
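// Field layout assumed by the index constants used below (the source struct
// is defined in this patch's TypeConverter.cpp; the target is the standard
// MLIR memref descriptor):
//   tensor-pointer struct: 0 = base ptr, 1 = offsets[rank], 2 = shape[rank], 3 = strides[rank]
//   memref descriptor:     0 = allocated ptr, 1 = aligned ptr, 2 = offset, 3 = sizes[rank], 4 = strides[rank]
// For example, copyValue(res, /*idxFrom=*/2, /*idxTo=*/3) copies the tensor
// pointer's shape into the memref sizes array.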
+struct ExtractMemRefOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractMemRefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto memRefTy = cast(op.getType()); + auto rank = memRefTy.getRank(); + auto memRefStructTy = getTypeConverter()->convertType(op.getType()); + auto memRefStructFields = + cast(memRefStructTy).getBody(); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto copyValue = [&](Value to, int64_t idxFrom, int64_t idxTo) { + auto valueTy = memRefStructFields[idxTo]; + Value val = rewriter.create( + loc, valueTy, tensorPtrStruct, idxFrom); + return rewriter.create(loc, memRefStructTy, to, val, + idxTo); + }; + + Value res = undef(memRefStructTy); + // Copy base. + res = copyValue(res, 0, 1); + // Use 0 offset. + res = rewriter.create(loc, memRefStructTy, res, + i64_val(0), 2); + // Copy shape. + res = copyValue(res, 2, 3); + // Copy strides. + res = copyValue(res, 3, 4); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct ExtractIndicesOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto rank = op.getNumResults(); + auto i64Ty = IntegerType::get(getContext(), 64); + SmallVector indices; + + for (int64_t i = 0; i < rank; i++) { + Value offs = rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{1, i}); + Value stride = rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{3, i}); + indices.push_back(rewriter.create(loc, offs, stride)); + } + + rewriter.replaceOp(op, indices); + + return success(); + } +}; + +struct MakeTensorPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto structTy = getTypeConverter()->convertType(op.getType()); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto insertArray = [&](Value structVal, auto values, int64_t idx, + Type zextTo = nullptr) { + for (int64_t i = 0; i < static_cast(values.size()); ++i) { + Value val = values[i]; + if (zextTo) + val = rewriter.create(loc, zextTo, val); + structVal = rewriter.create( + loc, structTy, structVal, val, SmallVector{idx, i}); + } + return structVal; + }; + + Value res = undef(structTy); + // 0 - base pointer. + auto base = rewriter.getRemappedValue(op.getBase()); + res = rewriter.create(loc, structTy, res, base, 0); + // 1 - array for offsets. Promote values to i64. + res = insertArray(res, op.getOffsets(), 1, i64Ty); + // 2 - array for shape. + res = insertArray(res, op.getShape(), 2); + // 3 - array for strides. 
+ res = insertArray(res, op.getStrides(), 3); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct AdvanceOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i64Ty = IntegerType::get(getContext(), 64); + Value res = rewriter.getRemappedValue(op.getPtr()); + Type structTy = res.getType(); + auto offsets = op.getOffsets(); + + for (int64_t i = 0; i < offsets.size(); ++i) { + auto oldOffset = rewriter.create( + loc, i64Ty, res, SmallVector{1, i}); + auto step = rewriter.create(loc, i64Ty, offsets[i]); + auto newOffset = rewriter.create(loc, oldOffset, step); + res = rewriter.create(loc, structTy, res, newOffset, + SmallVector{1, i}); + } + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct LoadOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type ptrTy = LLVM::LLVMPointerType::get(getContext()); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, ptr, 0, + op.getIsVolatile()); + return success(); + } +}; + +struct StoreOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value val = rewriter.getRemappedValue(op.getValue()); + rewriter.replaceOpWithNewOp(op, val, ptr); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct MemoryOpToLLVM + : public triton::impl::MemoryOpToLLVMBase { + using MemoryOpToLLVMBase::MemoryOpToLLVMBase; + + MemoryOpToLLVM() : MemoryOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + 
patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createMemoryOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp new file mode 100644 index 000000000000..914f56e668f8 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp @@ -0,0 +1,25 @@ +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +namespace triton { +namespace cpu { + +void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { + pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); + pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); + pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); + // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); +} + +void registerTritonCPUToLLVMPipeline() { + PassPipelineRegistration<>("triton-cpu-to-llvmir", + "TritonCPU to LLVM conversion pipeline.", + tritonCPUToLLVMPipelineBuilder); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000000..144cb57b1115 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp @@ -0,0 +1,43 @@ +#include "TypeConverter.h" + +using namespace mlir; +using namespace mlir::triton; + +TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([this](RankedTensorType tensorTy) -> std::optional { + if (isa(tensorTy.getElementType())) + return VectorType::get(tensorTy.getShape(), + IntegerType::get(tensorTy.getContext(), 64)); + return std::nullopt; + }); +} + +Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + // struct { + // ptr base_ptr; + // array offsets; + // array shape; + // array strides; + // } + auto tensorTy = cast(pointeeType); + auto rank = tensorTy.getShape().size(); + auto i64Ty = IntegerType::get(ctx, 64); + SmallVector types; + types.push_back(LLVM::LLVMPointerType::get(ctx)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx); +} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h new file mode 100644 index 000000000000..35d74a9ec430 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h @@ -0,0 +1,22 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include 
"triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonPointerType(triton::PointerType type); +}; + +#endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..9fa892b449ac --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,15 @@ +add_triton_library(TritonToTritonCPU + ConvertDotOp.cpp + ConvertElementwiseOps.cpp + ConvertMemoryOps.cpp + ConvertPtrOps.cpp + Pipeline.cpp + TypeConverter.cpp + + DEPENDS + TritonToTritonCPUPassIncGen + + LINK_LIBS PUBLIC + TritonCPUIR + MLIRVectorDialect +) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp new file mode 100644 index 000000000000..b6fbb1893202 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -0,0 +1,102 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTDOTOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class PtrConversionTarget : public ConversionTarget { +public: + explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct DotOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Value a = rewriter.getRemappedValue(op.getA()); + Value b = rewriter.getRemappedValue(op.getB()); + Value c = rewriter.getRemappedValue(op.getC()); + auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } +}; + +struct ConvertDotOp : 
public triton::impl::ConvertDotOpBase { + using ConvertDotOpBase::ConvertDotOpBase; + + ConvertDotOp() : ConvertDotOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + PtrConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp new file mode 100644 index 000000000000..70e8c4ed3c66 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -0,0 +1,300 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTELEMENTWISEOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ElementwiseOpConversionTarget : public ConversionTarget { +public: + explicit ElementwiseOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + } +}; + +template +struct ElementwiseOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using typename OpConversionPattern::OpAdaptor; + + LogicalResult + matchAndRewrite(OpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + OperationState newState(op.getLoc(), ResOpT::getOperationName()); + // Convert operands. + for (auto operand : op->getOperands()) { + Value newOperand = rewriter.getRemappedValue(operand); + newState.operands.push_back(newOperand); + } + // Convert result types. 
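+    // Result types are converted the same way as the operands (tensors become
+    // vectors), so e.g. an op over tensor<16x32xf32> is recreated over
+    // vector<16x32xf32> (illustrative shapes).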
+ if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newState.types))) { + return failure(); + } + newState.attributes = op->getAttrs(); + + auto newOp = rewriter.create(newState); + rewriter.replaceOp(op, newOp); + + return success(); + } +}; + +template <> +struct ElementwiseOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + assert(resTy); + if (auto denseAttr = dyn_cast(op.getValueAttr())) { + rewriter.replaceOpWithNewOp(op, resTy, + denseAttr.reshape(resTy)); + } else { + llvm_unreachable("Unexpected constant attribute"); + } + return success(); + } +}; + +template <> +struct ElementwiseOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcShape = dyn_cast(src.getType()).getShape(); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto dstShape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + + // There are restrictions on how shape can be modified by ShapeCastOp + // when rank is changed. For now, we simply detect it and handle through + // a cast to 1D vector. Better solution may be required later. + if (canCastShape(srcShape, dstShape)) { + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), src); + } else { + SmallVector tmpShape({resTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, elemTy), src); + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), tmp); + } + return success(); + } + +private: + bool canCastShape(ArrayRef src, ArrayRef dst) const { + if (src.size() == dst.size()) + return true; + if (src.size() > dst.size()) + return canCastShape(dst, src); + + size_t srcIdx = 0; + size_t dstIdx = 0; + while (srcIdx < src.size() && dstIdx < dst.size()) { + if (src[srcIdx] == 1) { + ++srcIdx; + } else { + // Source dim size should be a product of continuous dest dim sizes. + int64_t srcSize = src[srcIdx++]; + int64_t dstSize = dst[dstIdx++]; + while (dstSize < srcSize && dstIdx < dst.size()) + dstSize *= dst[dstIdx++]; + if (dstSize != srcSize) + return false; + } + } + + // Skip trailing 1s. 
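+    // For example (illustrative shapes): (64) -> (8, 8) passes this check and
+    // lowers to a single shape cast, while (2, 32) -> (16, 4) does not and
+    // takes the rank-1 round trip above.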
+ while (srcIdx < src.size() && src[srcIdx] == 1) + ++srcIdx; + while (dstIdx < dst.size() && dst[dstIdx] == 1) + ++dstIdx; + + return srcIdx == src.size() && dstIdx == dst.size(); + } +}; + +struct ConvertElementwiseOps + : public triton::impl::ConvertElementwiseOpsBase { + using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; + + ConvertElementwiseOps() : ConvertElementwiseOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ElementwiseOpConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + + patterns.add>( + typeConverter, context); + patterns + .add>( + typeConverter, context); + patterns.add< + ElementwiseOpConversion>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>(typeConverter, + context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertElementwiseOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp new file mode 100644 index 000000000000..1679ecc7af90 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -0,0 +1,277 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include 
"mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTMEMORYOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +struct LoadOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + auto mask = loadOp.getMask(); + auto ptr = loadOp.getPtr(); + auto boundaryChecks = loadOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + return lowerToScalarLoads(loadOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported load op"); + } + + auto memRef = rewriter.getRemappedValue(ptr); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto resTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + auto vecRead = rewriter.create(loc, resTy, memRef, + indices, inBounds); + rewriter.replaceOp(loadOp, vecRead); + return success(); + } + + LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + // Scalar loads and boundary checks are not expected. + assert(loadOp.getBoundaryCheck().empty()); + assert(isa(loadOp.getType())); + + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); + auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + auto ptrTy = + dyn_cast(loadOp.getPtr().getType()).getElementType(); + auto cache = loadOp.getCache(); + auto evict = loadOp.getEvict(); + auto isVolatile = loadOp.getIsVolatile(); + + Value defaultVal = loadOp.getOther(); + if (!defaultVal) + defaultVal = rewriter.create( + loc, rewriter.getZeroAttr(vecTy.getElementType())); + Value dst = rewriter.create(loc, vecTy, defaultVal); + + int64_t numElems = vecTy.getNumElements(); + auto strides = computeStrides(vecTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Block *headerBlock = rewriter.getBlock(); + Block *condBlock = nullptr; + Value origDst = dst; + // Create a conditional block for load if there is a mask. 
+ if (mask) { + condBlock = + rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(condBlock); + } + + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = + rewriter.create(loc, ptr, cache, evict, isVolatile); + dst = rewriter.create(loc, val, dst, indices); + + // Add predicate and branches. + if (mask) { + Block *footerBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + Value resDst = dst; + dst = footerBlock->addArgument(dst.getType(), dst.getLoc()); + rewriter.setInsertionPointToEnd(headerBlock); + auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, condBlock, + footerBlock, origDst); + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, footerBlock, resDst); + rewriter.setInsertionPointToStart(footerBlock); + } + } + + rewriter.replaceOp(loadOp, dst); + + return success(); + } +}; + +struct StoreOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + auto mask = storeOp.getMask(); + auto ptr = storeOp.getPtr(); + auto boundaryChecks = storeOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + return lowerToScalarStores(storeOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported store op"); + } + + auto value = rewriter.getRemappedValue(storeOp.getValue()); + auto memRef = rewriter.getRemappedValue(ptr); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + auto vecWrite = rewriter.create(loc, value, memRef, + indices, inBounds); + rewriter.replaceOp(storeOp, vecWrite); + return success(); + } + + LogicalResult lowerToScalarStores(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + // Scalar stores and boundary checks are not expected. + assert(storeOp.getBoundaryCheck().empty()); + assert(isa(storeOp.getValue().getType())); + + auto loc = storeOp.getLoc(); + auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); + auto mask = storeOp.getMask() ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto tensorTy = dyn_cast(storeOp.getPtr().getType()); + auto ptrTy = tensorTy.getElementType(); + auto cache = storeOp.getCache(); + auto evict = storeOp.getEvict(); + + int64_t numElems = tensorTy.getNumElements(); + auto strides = computeStrides(tensorTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Block *headerBlock = rewriter.getBlock(); + Block *condBlock = nullptr; + // Create a conditional block for store if there is a mask. + if (mask) { + condBlock = + rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(condBlock); + } + + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + rewriter.create(loc, ptr, val, cache, evict); + + // Add predicate and branches. 
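+      // Same header/cond/footer structure as the masked load above, except
+      // the footer takes no block argument since the store produces no value.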
+ if (mask) { + Block *footerBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(headerBlock); + auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, condBlock, + footerBlock); + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, footerBlock); + rewriter.setInsertionPointToStart(footerBlock); + } + } + + rewriter.eraseOp(storeOp); + + return success(); + } +}; + +class MemoryOpConversionTarget : public ConversionTarget { +public: + explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + // Allow only scalar loads and stores. + addDynamicallyLegalOp([](triton::LoadOp loadOp) { + return loadOp.getType().isIntOrIndexOrFloat(); + }); + addDynamicallyLegalOp([](triton::StoreOp storeOp) { + return storeOp.getValue().getType().isIntOrIndexOrFloat(); + }); + } +}; + +struct ConvertMemoryOps + : public triton::impl::ConvertMemoryOpsBase { + using ConvertMemoryOpsBase::ConvertMemoryOpsBase; + + ConvertMemoryOps() : ConvertMemoryOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + MemoryOpConversionTarget convTarget(*context); + TritonToTritonCPUTypeConverter pointerConverter; + RewritePatternSet patterns(context); + patterns.add(pointerConverter, context); + patterns.add(pointerConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertMemoryOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp new file mode 100644 index 000000000000..ade8b858bbfb --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -0,0 +1,191 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTPTROPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +unsigned getElemBitWidth(Type type) { + if (auto tensorTy = dyn_cast(type)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + if (auto vectorTy = dyn_cast(type)) + return vectorTy.getElementType().getIntOrFloatBitWidth(); + return type.getIntOrFloatBitWidth(); +} + +class PtrConversionTarget : public ConversionTarget { +public: + explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) 
{ + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + // Allow only scalar pointer conversion. + addDynamicallyLegalOp( + [](triton::PtrToIntOp op) { return op.getType().isInteger(); }); + addDynamicallyLegalOp([](triton::IntToPtrOp op) { + return op.getSrc().getType().isInteger(); + }); + } +}; + +struct MakeRangeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int32_t start = static_cast(op.getStart()); + int32_t end = static_cast(op.getEnd()); + assert(end >= start); + + llvm::SmallVector values; + values.reserve(end - start); + for (int32_t v = start; v < end; ++v) { + values.push_back(v); + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + auto newOp = rewriter.create( + op.getLoc(), resTy, rewriter.getI32VectorAttr(values)); + + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct SplatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value val = op.getSrc(); + Type dstValType = getTypeConverter()->convertType(val.getType()); + // Cast pointer + if (isa(val.getType())) + val = rewriter + .create( + loc, getTypeConverter()->convertType(val.getType()), val) + .getResult(); + Type resType = getTypeConverter()->convertType(op.getType()); + auto cast = rewriter.create(loc, resType, val); + + rewriter.replaceOp(op, cast); + return success(); + } +}; + +struct AddPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value offset = rewriter.getRemappedValue(op.getOffset()); + unsigned offsetBitWidth = getElemBitWidth(offset.getType()); + unsigned elemBitWidth = getPointeeBitWidth(op.getPtr().getType()); + // Compute scale. i1 elements take 1 byte. 
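+    // e.g. f32: (32 + 7) / 8 = 4 bytes; i1: (1 + 7) / 8 = 1 byte.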
+ Value scale = rewriter.create( + loc, (elemBitWidth + 7) / 8, offsetBitWidth); + if (isa(offset.getType())) + scale = rewriter.create(loc, offset.getType(), scale); + offset = rewriter.create(loc, offset, scale); + offset = rewriter.create(loc, ptr.getType(), offset); + rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { + using ConvertPtrOpsBase::ConvertPtrOpsBase; + + ConvertPtrOps() : ConvertPtrOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + PtrConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertPtrOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp new file mode 100644 index 000000000000..16bff114ed81 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -0,0 +1,26 @@ +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +namespace triton { +namespace cpu { + +void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); + pm.addPass(mlir::triton::cpu::createConvertPtrOps()); + pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); + pm.addPass(mlir::triton::cpu::createConvertDotOp()); + // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); +} + +void registerTritonToTritonCPUPipeline() { + PassPipelineRegistration<>("triton-to-triton-cpu", + "Triton to TritonCPU conversion pipeline.", + tritonToTritonCPUPipelineBuilder); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp new file mode 100644 index 000000000000..07b2da0468ba --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp @@ -0,0 +1,51 @@ +#include "TypeConverter.h" + +#include 
"triton/Dialect/TritonCPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([](triton::PointerType ptrTy) -> Type { + if (triton::isTensorPointerType(ptrTy)) { + // Tensor pointer is translated into a memref + auto tensorTy = dyn_cast(ptrTy.getPointeeType()); + auto elemTy = tensorTy.getElementType(); + // TODO: use dynamic strides + SmallVector shape(tensorTy.getRank(), ShapedType::kDynamic); + return MemRefType::get(shape, elemTy); + } + return IntegerType::get(ptrTy.getContext(), 64); + }); + addConversion([this](RankedTensorType tensorTy) -> Type { + Type elemTy = convertType(tensorTy.getElementType()); + return VectorType::get(tensorTy.getShape(), elemTy); + }); + + // Converted ops produce vectors instead of tensors. Provide conversion + // here for users. Also, convert pointers when required. + addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> std::optional { + if (isa(type)) + return builder.create(loc, type, inputs); + return builder.create(loc, type, inputs) + .getResult(0); + }); + + // Converted loads and stores consume memrefs instead of pointers, use extract + // op to get them. Also, provide conversion for vector users and pointer + // casts. + addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> std::optional { + if (type.isInteger() && isa(inputs.front().getType())) + return builder.create(loc, type, inputs); + if (isa(type)) + return builder.create(loc, type, inputs) + .getResult(0); + if (isa(type)) + return builder.create(loc, type, inputs); + llvm_unreachable("Unexpected target materizalization"); + }); +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h new file mode 100644 index 000000000000..cb89f0886c60 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h @@ -0,0 +1,19 @@ +#ifndef TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H + +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonToTritonCPUTypeConverter : public TypeConverter { +public: + using TypeConverter::convertType; + + TritonToTritonCPUTypeConverter(); + + Type convertTritonPointerType(triton::PointerType type); +}; + +#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 302951d04d59..efc949d6f4a1 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,9 +1,20 @@ +#include "TritonCPUToLLVM/Passes.h" +#include "TritonToTritonCPU/Passes.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "triton/Conversion/TritonCPUToLLVM/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" + #include #include #include @@ -14,8 +25,26 @@ namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module 
&&m) { using namespace mlir::triton; - m.def("add_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); + // m.def("add_to_llvmir", [](mlir::PassManager &pm) { + // pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); + // }); + m.def("add_triton_to_triton_cpu_pipeline", [](mlir::PassManager &pm) { + mlir::triton::cpu::tritonToTritonCPUPipelineBuilder(pm); + }); + m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { + mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); + }); + m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertVectorToLLVMPass()); + }); + m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + }); + m.def("add_math_to_libm", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertMathToLibmPass()); + }); + m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertFuncToLLVMPass()); }); } @@ -25,8 +54,18 @@ void init_triton_cpu(py::module &&m) { m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); + + m.def("find_kernel_names", [](mlir::ModuleOp &mod) { + std::vector res; + mod.walk([&](mlir::FunctionOpInterface funcOp) { + if (funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) + res.push_back(funcOp.getName().str()); + }); + return res; + }); } From e603b009433cb979c5ebad493ead3d00c50b8046 Mon Sep 17 00:00:00 2001 From: shanenay Date: Fri, 17 May 2024 14:48:48 -0700 Subject: [PATCH 010/165] Revert unreviewed changes. (#5) Co-authored-by: Shane Nay --- bin/RegisterTritonDialects.h | 8 - .../Conversion/TritonCPUToLLVM/CMakeLists.txt | 2 +- .../Dialect/TritonCPU/IR/TritonCPUAttrDefs.td | 5 +- .../Dialect/TritonCPU/IR/TritonCPUDialect.td | 3 - .../Dialect/TritonCPU/IR/TritonCPUOps.td | 51 --- lib/Conversion/CMakeLists.txt | 4 +- lib/Dialect/TritonCPU/IR/Dialect.cpp | 37 +- python/src/llvm.cc | 67 ---- python/src/passes.cc | 5 +- python/test/unit/language/test_core.py | 31 -- third_party/cpu/CMakeLists.txt | 5 - third_party/cpu/backend/compiler.py | 73 ++-- third_party/cpu/backend/driver.cpp | 224 ------------ third_party/cpu/backend/driver.py | 326 ++---------------- third_party/cpu/include/CMakeLists.txt | 2 - .../include/TritonCPUToLLVM/CMakeLists.txt | 3 - .../cpu/include/TritonCPUToLLVM/Passes.h | 36 -- .../cpu/include/TritonCPUToLLVM/Passes.td | 46 --- .../include/TritonToTritonCPU/CMakeLists.txt | 3 - .../cpu/include/TritonToTritonCPU/Passes.h | 37 -- .../cpu/include/TritonToTritonCPU/Passes.td | 62 ---- third_party/cpu/lib/CMakeLists.txt | 2 - .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 13 - .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 276 --------------- .../TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp | 98 ------ .../lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 277 --------------- .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp | 25 -- .../cpu/lib/TritonCPUToLLVM/TypeConverter.cpp | 43 --- .../cpu/lib/TritonCPUToLLVM/TypeConverter.h | 22 -- .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 15 - .../lib/TritonToTritonCPU/ConvertDotOp.cpp | 102 ------ .../ConvertElementwiseOps.cpp | 300 ---------------- .../TritonToTritonCPU/ConvertMemoryOps.cpp | 277 --------------- .../lib/TritonToTritonCPU/ConvertPtrOps.cpp | 191 ---------- .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 26 -- 
.../lib/TritonToTritonCPU/TypeConverter.cpp | 51 --- .../cpu/lib/TritonToTritonCPU/TypeConverter.h | 19 - third_party/cpu/triton_cpu.cc | 45 +-- 38 files changed, 70 insertions(+), 2742 deletions(-) delete mode 100644 third_party/cpu/backend/driver.cpp delete mode 100644 third_party/cpu/include/CMakeLists.txt delete mode 100644 third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt delete mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.h delete mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.td delete mode 100644 third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt delete mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.h delete mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.td delete mode 100644 third_party/cpu/lib/CMakeLists.txt delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index ca922e824793..17737e1096c6 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -17,8 +17,6 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" -#include "cpu/include/TritonCPUToLLVM/Passes.h" -#include "cpu/include/TritonToTritonCPU/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -71,12 +69,6 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); - // CPU passes - mlir::triton::cpu::registerTritonToTritonCPUPasses(); - mlir::triton::cpu::registerTritonToTritonCPUPipeline(); - mlir::triton::cpu::registerTritonCPUToLLVMPasses(); - mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); - // TODO: register Triton & TritonGPU passes registry.insert traits = [], string baseCppClass = "::mlir::Attribute"> : AttrDef { - let description = [{TritonCPU attr.}]; - let attrName = "triton.cpu." # attrMnemonic; + let description = [{ + WIP... 
+ }]; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td index 260db2743046..9ccac13f0b58 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td @@ -17,7 +17,6 @@ def TritonCPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "tensor::TensorDialect", - "mlir::memref::MemRefDialect", ]; let extraClassDeclaration = [{ @@ -25,8 +24,6 @@ def TritonCPU_Dialect : Dialect { }]; let useDefaultTypePrinterParser = 1; - let useDefaultAttributePrinterParser = 1; - let usePropertiesForAttributes = 1; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index bb7417ebd03e..16d9e433e899 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -7,57 +7,6 @@ include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" -include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" // Pure -include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType -include "mlir/Interfaces/DestinationStyleOpInterface.td" -include "mlir/Interfaces/ViewLikeInterface.td" - -class TTC_Op traits = []> : - Op { -} - -def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> { - let summary = "Extract base memref from a block pointer"; - - let description = [{ - Extract base memref from a block pointer. It covers whole base tensor memory, - not only the block referenced. Base pointer, shape, and strides are used - in the resulting memref. Offsets and block shape are ignored. - - }]; - - let arguments = (ins TT_TensorPtr:$src); - - let results = (outs AnyRankedOrUnrankedMemRef:$result); - - let hasCanonicalizer = 1; - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -} - -def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> { - let summary = "Extract indices from a block pointer."; - - let description = [{ - Extract indices that can be used to access the block using its base memref. - Indices are supposed to be used for vector loads/stores with the base - memref extracted from the same block pointer. 
- }]; - - let arguments = (ins TT_TensorPtr:$src); - - let results = (outs Variadic:$result); - - let builders = [ - OpBuilder<(ins "Value":$src)> - ]; - - let hasCanonicalizer = 1; - - let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; -} #endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 83db4ae41607..5c3aa2c1a827 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,4 @@ -#add_subdirectory(TritonToTritonCPU) +add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -#add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index e5eb53caf686..e28a65358dca 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -2,19 +2,16 @@ #include -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/TypeSwitch.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" - using namespace mlir; -using namespace mlir::triton; using namespace mlir::triton::cpu; //===----------------------------------------------------------------------===// @@ -23,35 +20,6 @@ using namespace mlir::triton::cpu; #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" -void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) {} - -void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) {} - -/// Parse an attribute registered to this dialect. -::mlir::Attribute -TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser, - ::mlir::Type type) const { - llvm_unreachable("parse stub called"); -} - -/// Print an attribute registered to this dialect. 
-void TritonCPUDialect::printAttribute(::mlir::Attribute attr, - ::mlir::DialectAsmPrinter &os) const { - llvm_unreachable("print stub called"); -} - -void ExtractIndicesOp::build(::mlir::OpBuilder &builder, - ::mlir::OperationState &state, Value src) { - assert(triton::isTensorPointerType(src.getType()) && - "Unexecpeted source type"); - auto tensorTy = dyn_cast( - dyn_cast(src.getType()).getPointeeType()); - SmallVector resTypes(tensorTy.getRank(), builder.getIndexType()); - build(builder, state, resTypes, src); -} - void TritonCPUDialect::initialize() { registerTypes(); @@ -66,9 +34,6 @@ void TritonCPUDialect::initialize() { >(); } -#define GET_OP_CLASSES -#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" - // verify TritonCPU ops LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/python/src/llvm.cc b/python/src/llvm.cc index ca7b9c911643..c86bf671a7df 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -3,8 +3,6 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" -#include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -23,7 +21,6 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" -#include "llvm/TargetParser/Host.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Instrumentation/AddressSanitizer.h" @@ -406,70 +403,6 @@ void init_triton_llvm(py::module &&m) { py::arg("flags") = std::vector{}, py::arg("enable_fp_fusion") = false); - m.def("set_host_target", [](llvm::Module *mod) { - mod->setTargetTriple(llvm::sys::getDefaultTargetTriple()); - std::string error; - auto target = - llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); - std::unique_ptr machine{target->createTargetMachine( - mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, - llvm::Reloc::PIC_)}; - mod->setDataLayout(machine->createDataLayout()); - }); - - m.def( - "translate_to_host_asm", - [](std::string llvmIR) -> py::object { - std::string res; - { - // when allow_threads goes out of scope, gil will be released - py::gil_scoped_release allow_threads; - // create LLVM module from C++ - llvm::LLVMContext context; - std::unique_ptr buffer = - llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); - llvm::SMDiagnostic error; - std::unique_ptr module = - llvm::parseIR(buffer->getMemBufferRef(), error, context); - if (!module) { - llvm::report_fatal_error( - "failed to parse IR: " + error.getMessage() + - "lineno: " + std::to_string(error.getLineNo())); - } - res = translateLLVMIRToASM( - *module, llvm::sys::getDefaultTargetTriple(), - llvm::sys::getHostCPUName().str(), "", {}, false, false); - } - return py::str(res); - }, - ret::take_ownership); - - m.def( - "translate_to_bc", - [](const std::string llvmIR) -> py::object { - py::gil_scoped_release allow_threads; - // create LLVM module - llvm::LLVMContext context; - std::unique_ptr buffer = - llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); - llvm::SMDiagnostic error; - std::unique_ptr module = - llvm::parseIR(buffer->getMemBufferRef(), error, context); - if (!module) { - llvm::report_fatal_error( - "failed to parse IR: " + error.getMessage() + - "lineno: " + std::to_string(error.getLineNo())); - } - // Write bitcode to a buffer. 
- llvm::SmallVector buf; - llvm::BitcodeWriter writer(buf); - writer.writeModule(*module); - writer.writeStrtab(); - std::string bitcode(buf.begin(), buf.end()); - return py::bytes(bitcode); - }, - ret::take_ownership); - m.def( "translate_to_asm", [](std::string llvmIR, std::string triple, std::string proc, diff --git a/python/src/passes.cc b/python/src/passes.cc index 9e34f6ad7fed..c365aaf43589 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -45,8 +45,8 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); - // ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", - // createConvertTritonToTritonCPUPass); + ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", + createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { @@ -89,7 +89,6 @@ void init_triton_passes_convert(py::module &&m) { ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); - ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass); } void init_triton_passes_llvmir(py::module &&m) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b42d715fea7b..582856774785 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -304,7 +304,6 @@ def filter_layouts(layouts): return [l for l in layouts if is_layout_applicable(l)] -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) def test_empty_kernel(dtype_x, device): @@ -544,7 +543,6 @@ def test_dtype_codegen(): # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -605,7 +603,6 @@ def promote_to_fp32(dtype_x, dtype_y): test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) def test_addptr(dtype, order, device): @@ -632,7 +629,6 @@ def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): np.testing.assert_allclose(y, to_numpy(y_tri)) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y", [ # (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes @@ -653,7 +649,6 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) -@pytest.mark.cpu def test_unsigned_name_mangling(device): # Test that uint32 and int32 are mangled differently by the compiler SIZE = 128 @@ -690,7 +685,6 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # test bitwise ops # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -715,7 +709,6 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes @@ -738,7 +731,6 @@ def 
test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "dtype_x, dtype_y, op, mode_x, mode_y", @@ -763,7 +755,6 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- # test broadcast # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def test_broadcast(dtype, device): @@ -798,7 +789,6 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con # ---------- -@pytest.mark.cpu @pytest.mark.interpreter def test_slice(device): @@ -830,7 +820,6 @@ def slice_kernel(XBLOCK: tl.constexpr): # ------------------ -@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_slice(device): dst = torch.empty(128, device=device) @@ -846,7 +835,6 @@ def _kernel(dst): # ---------------- # test expand_dims # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims(device): @@ -895,7 +883,6 @@ def expand_dims_kernel(dummy, N: tl.constexpr): expand_dims_kernel[(1, )](dummy_tensor, N) -@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims_error_cases(device): @@ -959,7 +946,6 @@ def duplicate_dim2(dummy, N: tl.constexpr): # ---------------------------- # test invalid program id axis # ---------------------------- -@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_pid_axis(device): dst = torch.empty(128, device=device) @@ -976,7 +962,6 @@ def _kernel(dst): # --------------- # test where # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1029,7 +1014,6 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl. 
assert (z == to_numpy(z_tri)).all() -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): @@ -1074,7 +1058,6 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr", [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') @@ -1089,7 +1072,6 @@ def test_unary_op(dtype_x, expr, num_ctas, device): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) @@ -1100,7 +1082,6 @@ def test_math_op(dtype_x, expr, x, device): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_erf_op(dtype, device): @@ -1122,7 +1103,6 @@ def kernel(Z, X, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_fma_op(dtype, device): @@ -1148,7 +1128,6 @@ def kernel(Z, X, Y, W, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1161,7 +1140,6 @@ def test_math_divide_op(expr, num_ctas, device): # ------------- # test precise math # ------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr_prec, expr_ref", [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), @@ -1202,7 +1180,6 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): @@ -1248,7 +1225,6 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter def test_shapes_as_params(device): @@ -1391,7 +1367,6 @@ def tuples_fn(a, b): a * b -@pytest.mark.cpu @pytest.mark.interpreter def test_tuples(device): @@ -1484,7 +1459,6 @@ def noinline_multi_values_fn(x, y, Z): tl.store(Z, z) -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) def test_noinline(mode, device): @@ -1807,7 +1781,6 @@ def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): # --------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ @@ -4757,7 +4730,6 @@ def kernel(VALUE, X): # ---------------- -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @@ -4795,7 +4767,6 @@ def kernel(Z, X, Y): np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) -@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_shape(device): @@ -4809,7 +4780,6 @@ def kernel(X): np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) -@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_scalar_shape(device): @@ -4827,7 +4797,6 @@ def kernel(X, s): reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 
2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] -@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("formats", reshape_list) def test_reshape(formats, device): diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index d8be71ad6c11..683889547b0a 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,8 +1,3 @@ -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) -include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) -add_subdirectory(include) -add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) - target_link_libraries(TritonCPU PUBLIC MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 357b5f448fe9..3c293cdf468f 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -4,7 +4,7 @@ import re from dataclasses import dataclass -from typing import Any, Tuple +from typing import Any from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget @@ -20,8 +20,6 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False - allowed_dot_input_precisions: Tuple[str] = ("ieee",) - allow_fp8e4nv: bool = False # TODO: We may introduce CPU-specific options like # of cores. @@ -42,7 +40,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "bc" + self.binary_ext = "exe" def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -64,6 +62,7 @@ def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) + passes.ttir.add_rewrite_tensor_pointer(pm) passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) @@ -78,34 +77,33 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - cpu.passes.ttcpuir.add_triton_to_triton_cpu_pipeline(pm) + passes.ttir.add_convert_to_ttcpuir(pm) + + # + # TODO: + # + passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) - passes.common.add_canonicalizer(pm) pm.run(mod) - metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) return mod @staticmethod def make_llir(src, metadata, options): - # warp-specialization mutates num_warps - num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") - if num_warp_groups is not None: - metadata["num_warps"] *= num_warp_groups - metadata["threads_per_warp"] = 1 mod = src # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) - cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) - passes.convert.add_math_to_llvmir(pm) - cpu.passes.ttcpuir.add_math_to_libm(pm) - cpu.passes.ttcpuir.add_vector_to_llvmir(pm) - cpu.passes.ttcpuir.add_memref_to_llvmir(pm) + + cpu.passes.ttcpuir.add_to_llvmir(pm) + passes.common.add_canonicalizer(pm) + passes.common.add_cse(pm) + + passes.convert.add_scf_to_cf(pm) + passes.convert.add_cf_to_llvmir(pm) passes.convert.add_arith_to_llvmir(pm) - cpu.passes.ttcpuir.add_func_to_llvmir(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) @@ -113,40 +111,45 @@ def 
make_llir(src, metadata, options): passes.llvmir.add_di_scope(pm) pm.run(mod) - # Find kernel fn - kernel_names = cpu.find_kernel_names(mod) - assert len(kernel_names) == 1, f"expected exactly 1 kernel in a module, got {kernel_names}" - # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) - llvm.set_host_target(llvm_mod) - #if options.extern_libs: - # paths = [path for (name, path) in options.extern_libs] - # llvm.link_extern_libs(llvm_mod, paths) + + # TODO: + if not llvm_mod: + metadata["shared"] = 0 + return src + + if options.extern_libs: + paths = [path for (name, path) in options.extern_libs] + llvm.link_extern_libs(llvm_mod, paths) llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - # Get some metadata + + # CPU doesn't have SMEM, but just to make it work for now. metadata["shared"] = 0 - metadata["name"] = kernel_names[0] + + # Cleanup ret = str(llvm_mod) del llvm_mod del context return ret @staticmethod - def make_bc(src, metadata, options): - if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": - print("********** Module ASM **********") - print(llvm.translate_to_host_asm(src)) - ret = llvm.translate_to_bc(src) - return ret + def make_exe(src, metadata, options): + # Just a quick hack while developing the backend. + names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) + assert len(names) == 1 + metadata["name"] = names[0] + + # TODO: Call llc to create an executable. + return src def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options) + stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.cpp b/third_party/cpu/backend/driver.cpp deleted file mode 100644 index babff3dfdebe..000000000000 --- a/third_party/cpu/backend/driver.cpp +++ /dev/null @@ -1,224 +0,0 @@ -//===- driver.cpp ---------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "llvm/Bitcode/BitcodeReader.h" -#include "llvm/ExecutionEngine/Orc/CompileUtils.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" -#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/TargetSelect.h" - -#include -#include -#include -#include -#include -#include -#include - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - - return Py_BuildValue("{s:i}", "max_shared_mem", 0); -} - -bool getBoolEnv(const std::string &env) { - const char *s = std::getenv(env.c_str()); - std::string str(s ? s : ""); - std::transform(str.begin(), str.end(), str.begin(), - [](unsigned char c) { return std::tolower(c); }); - return (str == "on" || str == "true" || str == "1"); -} - -llvm::orc::ThreadSafeContext &getThreadSafeContext() { - static llvm::orc::ThreadSafeContext tsc; - static std::once_flag init_flag; - std::call_once(init_flag, []() { - auto context = std::make_unique(); - tsc = llvm::orc::ThreadSafeContext(std::move(context)); - }); - return tsc; -} - -std::string llvmErrToString(const llvm::Error &err) { - std::string res; - llvm::raw_string_ostream os(res); - os << err; - return res; -}; - -struct CompiledKernel { - std::unique_ptr execution_session; - std::unique_ptr data_layout; - std::unique_ptr mangle; - std::unique_ptr object_layer; - std::unique_ptr compiler_layer; - llvm::orc::JITDylib *dylib = nullptr; - - CompiledKernel() = default; - CompiledKernel(CompiledKernel &&) = default; - - ~CompiledKernel() { - if (execution_session) - llvm::cantFail(execution_session->endSession()); - } -}; - -std::vector> compiled_kernels; - -static PyObject *loadBitcode(PyObject *self, PyObject *args) { - const char *name; - int shared; - PyObject *py_bytes; - int devId; - - if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) { - std::cerr << "loadBitcode arg parse failed" << std::endl; - return NULL; - } - - std::string kernel_name = name; - size_t binary_size = PyBytes_Size(py_bytes); - const char *binary_ptr = PyBytes_AsString(py_bytes); - - llvm::LLVMContext context; - auto buf = llvm::MemoryBuffer::getMemBuffer( - llvm::StringRef(binary_ptr, binary_size)); - auto mod = llvm::parseBitcodeFile(*buf, context); - if (!mod) { - std::cerr << "Failed to parse LLVM bitcode module" << std::endl; - return NULL; - } - - if (getBoolEnv("MLIR_ENABLE_DUMP")) { - llvm::errs() << "********** Loaded Module (kernel_name=" << name - << ") **********\n" - << **mod << "\n"; - } - - auto init_err = llvm::InitializeNativeTarget(); - if (init_err) { - std::cerr << "Failed to initialize native target." 
<< std::endl; - return NULL; - } - - llvm::InitializeNativeTargetAsmPrinter(); - llvm::InitializeNativeTargetAsmParser(); - - auto self_epc = - llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); - - auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!detect_host_res) { - std::cerr << "Failed to initialize JITTargetMachineBuilder: " - << llvmErrToString(detect_host_res.takeError()); - return NULL; - } - llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); - - auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); - if (!data_layout_res) { - std::cerr << "Failed to initialize data layout: " - << llvmErrToString(data_layout_res.takeError()); - return NULL; - } - - CompiledKernel kernel; - kernel.execution_session = - std::make_unique(std::move(self_epc)); - kernel.data_layout = - std::make_unique(std::move(*data_layout_res)); - kernel.mangle = std::make_unique( - *kernel.execution_session, *kernel.data_layout); - kernel.object_layer = std::make_unique( - *kernel.execution_session, - []() { return std::make_unique(); }); - kernel.compiler_layer = std::make_unique( - *kernel.execution_session, *kernel.object_layer, - std::make_unique(std::move(tmb))); - - auto dylib_res = kernel.execution_session->createJITDylib("
"); - if (!dylib_res) { - std::cerr << "Failed to create initialize JITDylib: " - << llvmErrToString(dylib_res.takeError()); - return NULL; - } - - kernel.dylib = &(*dylib_res); - kernel.dylib->addGenerator(llvm::cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - kernel.data_layout->getGlobalPrefix()))); - - // Compile module. - (**mod).setDataLayout(*kernel.data_layout); - llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); - auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); - if (err) { - std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); - return NULL; - } - - // Find kernel function pointer. - auto lookup_res = - kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); - if (!lookup_res) { - std::cerr << "Failed to find function " << std::string(name) - << "\nError: " << llvmErrToString(lookup_res.takeError()); - return NULL; - } - uint64_t fn_ptr = lookup_res->getAddress().getValue(); - - compiled_kernels.push_back( - std::make_unique(std::move(kernel))); - auto *kernel_ptr = compiled_kernels.back().get(); - - return Py_BuildValue("(KKii)", reinterpret_cast(kernel_ptr), - reinterpret_cast(fn_ptr), 0, 0); -} - -static PyObject *initContext(PyObject *self, PyObject *args) { - return Py_BuildValue("(K)", (uint64_t)0); -} - -static PyObject *initDevices(PyObject *self, PyObject *args) { - return Py_BuildValue("(i)", 1); -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBitcode, METH_VARARGS, - "Load provided SPV into ZE driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cpu_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_cpu_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - PyModule_AddFunctions(m, ModuleMethods); - return m; -} diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 743684d2640f..3f3816a99b9f 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,100 +1,5 @@ -import os -import hashlib -import tempfile -from pathlib import Path -from triton.runtime.build import _build -from triton.runtime.cache import get_cache_manager -from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget - -dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") -llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") -llvm_root = os.path.expanduser(llvm_root) -llvm_dirs = os.listdir(llvm_root) -if len(llvm_dirs) == 1: - llvm_root = os.path.join(llvm_root, llvm_dirs[0]) -include_dir = [ - os.path.join(dirname, "include"), - os.path.join(llvm_root, "include"), -] -library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] -libraries = [ - "LLVMOrcJIT", - "LLVMPasses", - "LLVMX86CodeGen", - "LLVMX86AsmParser", - "LLVMX86Desc", - "LLVMX86Info", - "LLVMGlobalISel", - "LLVMSelectionDAG", - "LLVMHipStdPar", - "LLVMCoroutines", - "LLVMipo", - "LLVMFrontendOpenMP", - "LLVMInstrumentation", - "LLVMAsmPrinter", - "LLVMCodeGen", - "LLVMObjCARCOpts", - "LLVMLinker", - "LLVMVectorize", - "LLVMScalarOpts", - "LLVMInstCombine", - "LLVMFrontendOffloading", - "LLVMExecutionEngine", - "LLVMAggressiveInstCombine", - "LLVMTransformUtils", - "LLVMTarget", - "LLVMRuntimeDyld", - "LLVMJITLink", - "LLVMIRPrinter", - 
"LLVMBitWriter", - "LLVMAnalysis", - "LLVMProfileData", - "LLVMSymbolize", - "LLVMDebugInfoDWARF", - "LLVMObject", - "LLVMTextAPI", - "LLVMMCParser", - "LLVMMCDisassembler", - "LLVMMC", - "LLVMIRReader", - "LLVMCFGuard", - "LLVMBitReader", - "LLVMAsmParser", - "LLVMCore", - "LLVMBinaryFormat", - "LLVMOrcTargetProcess", - "LLVMTargetParser", - "LLVMRemarks", - "LLVMOrcShared", - "LLVMOption", - "LLVMDebugInfoCodeView", - "LLVMCodeGenTypes", - "LLVMBitstreamReader", - "LLVMSupport", - "LLVMDemangle", - "stdc++", -] - - -def compile_module_from_src(src, name): - key = hashlib.md5(src.encode("utf-8")).hexdigest() - cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - src_path = os.path.join(tmpdir, "main.cpp") - with open(src_path, "w") as f: - f.write(src) - so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) - import importlib.util - spec = importlib.util.spec_from_file_location(name, cache_path) - mod = importlib.util.module_from_spec(spec) - spec.loader.exec_module(mod) - return mod - +from triton.backends.driver import CPUDriverBase # ------------------------ # Utils @@ -110,12 +15,22 @@ def __new__(cls): def __init__(self): pass - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") - self.load_binary = mod.load_binary - def get_device_properties(self, *args): - return {"max_shared_mem": 0} + @staticmethod + def get_device_properties(device): + # This is just dummy for now. We will need to implement driver.c. + return { + "max_shared_mem": 0, + "multiprocessor_count": 0, + "sm_clock_rate": 0, + "mem_clock_rate": 0, + "mem_bus_width": 0, + } + + @staticmethod + def load_binary(name, kernel_asm, shared, device): + # This is just dummy for now. We will need to implement driver.c. + return (None, kernel_asm, 0, 0) # ------------------------ @@ -123,228 +38,27 @@ def get_device_properties(self, *args): # ------------------------ -def ty_to_cpp(ty): - if ty[0] == '*': - return "void*" - return { - "i1": "int32_t", - "i8": "int8_t", - "i16": "int16_t", - "i32": "int32_t", - "i64": "int64_t", - "u1": "uint32_t", - "u8": "uint8_t", - "u16": "uint16_t", - "u32": "uint32_t", - "u64": "uint64_t", - "fp16": "float", - "bf16": "float", - "fp32": "float", - "f32": "float", - "fp64": "double", - }[ty] - - def make_launcher(constants, signature, ids): - # Record the end of regular arguments; - # subsequent arguments are architecture-specific descriptors. 
- arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) - arg_types = (', '.join(f"{ty_to_cpp(ty)}" for i, ty in signature.items()) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" - - def _extracted_type(ty): - if ty[0] == '*': - return "PyObject*" - return ty_to_cpp(ty) - - def format_of(ty): - return { - "PyObject*": "O", - "float": "f", - "double": "d", - "long": "l", - "int8_t": "b", - "int16_t": "h", - "int32_t": "i", - "int64_t": "l", - "uint8_t": "B", - "uint16_t": "H", - "uint32_t": "I", - "uint64_t": "K", - }[ty] - - args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) - format = "iiiOKOOOO" + args_format - args_list = ', '.join(f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' - - # generate glue code - src = f""" -#include -#include -#include -#include - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include -#include - -using kernel_ptr_t = void(*)({arg_types}); - -typedef struct _DevicePtrInfo {{ - void* dev_ptr; - bool valid; -}} DevicePtrInfo; - -static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ - DevicePtrInfo ptr_info; - ptr_info.dev_ptr = 0; - ptr_info.valid = true; - if (PyLong_Check(obj)) {{ - ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); - return ptr_info; - }} - if (obj == Py_None) {{ - // valid nullptr - return ptr_info; - }} - PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); - if(ptr){{ - PyObject *empty_tuple = PyTuple_New(0); - PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); - Py_DECREF(empty_tuple); - Py_DECREF(ptr); - if (!PyLong_Check(ret)) {{ - PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); - ptr_info.valid = false; - return ptr_info; - }} - ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); - if(!ptr_info.dev_ptr) {{ - return ptr_info; - }} - Py_DECREF(ret); // Thanks ChatGPT! 
- return ptr_info; - }} - PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); - ptr_info.valid = false; - return ptr_info; -}} - -static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - // TODO: add OMP pragmas to run in parallel - for (uint32_t z = 0; z < gridZ; ++z) {{ - for (uint32_t y = 0; y < gridY; ++y) {{ - for (uint32_t x = 0; x < gridX; ++x) {{ - (*kernel_ptr)({args_list + ', ' if len(arg_decls) > 0 else ''} x, y, z); - }} - }} - }} -}} - -static PyObject* launch(PyObject* self, PyObject* args) {{ - - - int gridX, gridY, gridZ; - PyObject *launch_enter_hook = NULL; - PyObject *launch_exit_hook = NULL; - PyObject *kernel_metadata = NULL; - PyObject *launch_metadata = NULL; - PyObject *py_obj_stream; - void* pKrnl; - - {' '.join([f"{_extracted_type(ty)} arg{i}; " for i, ty in signature.items()])} - if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, - &kernel_metadata, &launch_metadata, - &launch_enter_hook, &launch_exit_hook {', ' + arg_ptrs_list if len(signature) > 0 else ''})) {{ - return NULL; - }} - - void *pStream = PyLong_AsVoidPtr(py_obj_stream); - kernel_ptr_t kernel_ptr = reinterpret_cast(pKrnl); - - // extract launch metadata - if (launch_enter_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_enter_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - run_omp_kernels(gridX, gridY, gridZ, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); - - if(launch_exit_hook != Py_None){{ - PyObject* args = Py_BuildValue("(O)", launch_metadata); - PyObject* ret = PyObject_CallObject(launch_exit_hook, args); - Py_DECREF(args); - if (!ret) - return NULL; - }} - - if (PyErr_Occurred()) {{ - return NULL; - }} - - // return None - Py_INCREF(Py_None); - return Py_None; -}} - -static PyMethodDef ModuleMethods[] = {{ - {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, - {{NULL, NULL, 0, NULL}} // sentinel -}}; - -static struct PyModuleDef ModuleDef = {{ - PyModuleDef_HEAD_INIT, - \"__triton_cpu_launcher\", - NULL, //documentation - -1, //size - ModuleMethods -}}; - -PyMODINIT_FUNC PyInit___triton_cpu_launcher(void) {{ - PyObject *m = PyModule_Create(&ModuleDef); - if(m == NULL) {{ - return NULL; - }} - PyModule_AddFunctions(m, ModuleMethods); - return m; -}} -""" - return src + pass class CPULauncher(object): def __init__(self, src, metadata): - ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} - constants = src.constants if hasattr(src, "constants") else dict() - cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i - constants = {cst_key(key): value for key, value in constants.items()} - signature = {cst_key(key): value for key, value in src.signature.items()} - src = make_launcher(constants, signature, ids) - mod = compile_module_from_src(src, "__triton_cpu_launcher") - self.launch = mod.launch + # TODO: + self.launch = lambda *args, **kwargs: None def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) -class CPUDriver(DriverBase): +class CPUDriver(CPUDriverBase): def 
__init__(self): self.utils = CPUUtils() self.launcher_cls = CPULauncher super().__init__() - def get_current_device(self): - return 0 - - def get_current_stream(self, device): - return 0 - def get_current_target(self): # Capability and warp size are zeros for CPU. # TODO: GPUTarget naming isn't obviously good. diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt deleted file mode 100644 index fc9a19e52b0d..000000000000 --- a/third_party/cpu/include/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(TritonCPUToLLVM) -add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt deleted file mode 100644 index 64b36523d35d..000000000000 --- a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) -add_public_tablegen_target(TritonCPUConversionPassIncGen) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h deleted file mode 100644 index 74f74b00870c..000000000000 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H -#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include - -namespace mlir { - -class ModuleOp; -template class OperationPass; - -namespace triton { -namespace cpu { - -#define GEN_PASS_DECL -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" - -std::unique_ptr> createFuncOpToLLVMPass(); -std::unique_ptr> createMemoryOpToLLVMPass(); -std::unique_ptr> createGetProgramIdOpToLLVMPass(); - -void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); -void registerTritonCPUToLLVMPipeline(); - -#define GEN_PASS_REGISTRATION -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" - -} // namespace cpu -} // namespace triton - -} // namespace mlir - -#endif diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td deleted file mode 100644 index c75b58b572f1..000000000000 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef TRITONCPU_CONVERSION_PASSES -#define TRITONCPU_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert FuncOp to LLVM for CPU."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createFuncOpToLLVMPass()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::scf::SCFDialect", - "mlir::LLVM::LLVMDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert Triton memory operations to LLVM for CPU."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createMemoryOpToLLVMPass()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::scf::SCFDialect", - "mlir::memref::MemRefDialect", - "mlir::LLVM::LLVMDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert Triton 
GetProgramId to LLVM for CPU."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createGetProgramIdOpToLLVMPass()"; - - let dependentDialects = ["mlir::LLVM::LLVMDialect", - "mlir::triton::TritonDialect"]; -} - -#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt deleted file mode 100644 index 56e231273ed6..000000000000 --- a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) -add_public_tablegen_target(TritonToTritonCPUPassIncGen) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h deleted file mode 100644 index ab98a8741a16..000000000000 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES_H -#define TRITONTOTRITONCPU_CONVERSION_PASSES_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include - -namespace mlir { - -class ModuleOp; -template class OperationPass; - -namespace triton { -namespace cpu { - -#define GEN_PASS_DECL -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" - -std::unique_ptr> createConvertElementwiseOps(); -std::unique_ptr> createConvertMemoryOps(); -std::unique_ptr> createConvertPtrOps(); -std::unique_ptr> createConvertDotOp(); - -void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); -void registerTritonToTritonCPUPipeline(); - -#define GEN_PASS_REGISTRATION -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" - -} // namespace cpu -} // namespace triton - -} // namespace mlir - -#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td deleted file mode 100644 index 77e6528c6943..000000000000 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES -#define TRITONTOTRITONCPU_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { - let summary = "Convert Triton memory ops."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::ModuleOp"> { - let summary = "Convert elementwise ops."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertElementwiseOps()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { - let summary = "Convert Triton ops related to pointer arithmetics."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertPtrOps()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -def ConvertDotOp 
: Pass<"triton-cpu-convert-dot-op", "mlir::ModuleOp"> { - let summary = "Convert Triton DotOp."; - let description = [{ - - }]; - let constructor = "mlir::triton::cpu::createConvertDotOp()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", - "mlir::triton::TritonDialect", - "mlir::triton::cpu::TritonCPUDialect"]; -} - -#endif diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt deleted file mode 100644 index fc9a19e52b0d..000000000000 --- a/third_party/cpu/lib/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(TritonCPUToLLVM) -add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt deleted file mode 100644 index 884c9352ef1b..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_triton_library(TritonCPUToLLVM - FuncOpToLLVM.cpp - GetProgramIdOpToLLVM.cpp - MemoryOpToLLVM.cpp - Pipeline.cpp - TypeConverter.cpp - - DEPENDS - TritonCPUToLLVMConversionPassIncGen - - LINK_LIBS PUBLIC - MLIRVectorToLLVMPass -) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp deleted file mode 100644 index 5895341fc34b..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Pass/Pass.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_FUNCOPTOLLVM -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -namespace mlir { -FailureOr -convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter); -} - -using namespace mlir; -using namespace mlir::triton; - -namespace { - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalOp(); - } -}; - -struct FuncOpConversion : public ConvertOpToLLVMPattern { - FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - - /// Only retain those attributes that are not constructed by - /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument - /// attributes. 
- static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, - SmallVectorImpl &result) { - - for (const auto &attr : op->getAttrs()) { - if (attr.getName() == SymbolTable::getSymbolAttrName() || - attr.getName() == op.getFunctionTypeAttrName() || - attr.getName() == "std.varargs" || - (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) - continue; - result.push_back(attr); - } - } - - triton::FuncOp amendProgramIdArgs(triton::FuncOp funcOp, - ConversionPatternRewriter &rewriter) const { - // Push back a variable that indicates the current stack pointer of shared - // memory to the function arguments. - auto loc = funcOp.getLoc(); - auto ctx = funcOp->getContext(); - // 1. Modify the function type to add new arguments. - auto funcTy = funcOp.getFunctionType(); - auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); - amendedInputTy.push_back(i32_ty); - amendedInputTy.push_back(i32_ty); - amendedInputTy.push_back(i32_ty); - auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, - funcTy.getResults()); - // 2. Modify the argument attributes to add new arguments. - SmallVector amendedAttrs; - filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); - SmallVector amendedArgAttrs; - if (funcOp.getAllArgAttrs()) - amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedAttrs.push_back(rewriter.getNamedAttr( - funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); - // 3. Add a new arguments to the region - auto amendedFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); - auto ®ion = funcOp.getBody(); - region.addArgument(funcTy, loc); - rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), - amendedFuncOp.end()); - return amendedFuncOp; - } - - LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // Prevent LLVM's inliner to inline this function - auto modifiedFuncOp = funcOp; - if (LLVM::isKernel(funcOp)) - modifiedFuncOp = amendProgramIdArgs(modifiedFuncOp, rewriter); - - LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( - modifiedFuncOp, rewriter, *getTypeConverter()); - if (!newFuncOp) - return failure(); - - // required by AxisInfoAnalysis - if (LLVM::isKernel(funcOp)) - rewriter.eraseOp(modifiedFuncOp); - rewriter.eraseOp(funcOp); - return success(); - } -}; - -struct ReturnOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - LLVM::ReturnOp newOp; - if (adaptor.getOperands().size() < 2) { - // Single or no return value. - newOp = - rewriter.create(op.getLoc(), adaptor.getOperands()); - } else { - // Pack the results into a struct. 
- auto funcOp = op->getParentOfType(); - auto packedResultsTy = this->getTypeConverter()->packFunctionResults( - funcOp.getResultTypes()); - Value packedResults = - rewriter.create(op.getLoc(), packedResultsTy); - auto loc = op.getLoc(); - for (auto it : llvm::enumerate(adaptor.getOperands())) { - packedResults = - insert_val(packedResultsTy, packedResults, it.value(), it.index()); - } - newOp = rewriter.create(op.getLoc(), packedResults); - } - newOp->setAttrs(op->getAttrs()); - rewriter.replaceOp(op, newOp->getResults()); - return success(); - } -}; - -// CallOpInterfaceLowering is adapted from -// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 -struct CallOpConversion : public ConvertOpToLLVMPattern { - CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - - LogicalResult - matchAndRewrite(triton::CallOp callOp, - typename triton::CallOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); - auto newCallOp = - convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); - if (!newCallOp) - return failure(); - auto results = getCallOpResults(callOp, newCallOp, rewriter); - rewriter.replaceOp(callOp, results); - return success(); - } - -private: - SmallVector - promoteOperands(triton::CallOp callOp, - typename triton::CallOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = callOp.getLoc(); - auto caller = callOp->getParentOfType(); - auto promotedOperands = this->getTypeConverter()->promoteOperands( - callOp.getLoc(), /*opOperands=*/callOp->getOperands(), - adaptor.getOperands(), rewriter); - return promotedOperands; - } - - LLVM::CallOp - convertCallOpToLLVMCallOp(triton::CallOp callOp, - ArrayRef promotedOperands, - ConversionPatternRewriter &rewriter) const { - // Pack the result types into a struct. - Type packedResult = nullptr; - unsigned numResults = callOp.getNumResults(); - auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - - if (numResults != 0) { - if (!(packedResult = - this->getTypeConverter()->packFunctionResults(resultTypes))) - return nullptr; - } - auto newCallOp = rewriter.create( - callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), - promotedOperands, callOp->getAttrs()); - return newCallOp; - } - - SmallVector - getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, - ConversionPatternRewriter &rewriter) const { - auto numResults = callOp.getNumResults(); - SmallVector results; - if (numResults < 2) { - // If < 2 results, packing did not do anything and we can just return. - results.append(newCallOp.result_begin(), newCallOp.result_end()); - } else { - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. 
- results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - results.push_back(rewriter.create( - callOp.getLoc(), newCallOp->getResult(0), i)); - } - } - return results; - } -}; - -struct FuncOpToLLVM : public triton::impl::FuncOpToLLVMBase { - using FuncOpToLLVMBase::FuncOpToLLVMBase; - - FuncOpToLLVM() : FuncOpToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - // Lower tt.func - RewritePatternSet funcPatterns(context); - funcPatterns.add(typeConverter, - /*benefit=*/1); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - funcPatterns); - if (failed( - applyPartialConversion(mod, convTarget, std::move(funcPatterns)))) - return signalPassFailure(); - - // Lower tt.call, tt.return - int benefit = 10; - RewritePatternSet patterns(context); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createFuncOpToLLVMPass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp deleted file mode 100644 index 4c593f1ff7aa..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_GETPROGRAMIDOPTOLLVM -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalOp(); - } -}; - -// TODO: use enums to access struct fields. 
-struct GetProgramIdOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto funcOp = op->getParentOfType(); - assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); - auto args = funcOp.getArguments(); - // Last three args are x, y, z program ids. - auto argIdx = args.size() - 3 + op.getAxisAsInt(); - assert(argIdx < args.size() && "out-of-bounds arg index"); - assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); - rewriter.replaceOp(op, args[argIdx]); - return success(); - } -}; - -struct GetProgramIdOpToLLVM - : public triton::impl::GetProgramIdOpToLLVMBase { - using GetProgramIdOpToLLVMBase::GetProgramIdOpToLLVMBase; - - GetProgramIdOpToLLVM() : GetProgramIdOpToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createGetProgramIdOpToLLVMPass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp deleted file mode 100644 index 594495c4ab9d..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp +++ /dev/null @@ -1,277 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_MEMORYOPTOLLVM -#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalOp(); - } -}; - -// TODO: use enums to access struct fields. 
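The FuncOpToLLVM and GetProgramIdOpToLLVM passes removed above, together with the generated C launcher, share one calling convention: every kernel function is amended with three trailing i32 arguments carrying the program IDs, the host walks the grid serially, and `tt.get_program_id(axis)` simply reads one of those trailing arguments. The following is a minimal, illustrative Python sketch of that convention, not code from the patch; the names `run_grid`, `program_id`, `kernel`, and `args` are assumptions for illustration only.

```python
# Sketch of the launch convention used by the removed lowering (assumed names).

def run_grid(kernel, args, grid_x, grid_y, grid_z):
    # Mirrors run_omp_kernels in the generated launcher: the three grid
    # coordinates are appended after the regular kernel arguments.
    for z in range(grid_z):
        for y in range(grid_y):
            for x in range(grid_x):
                kernel(*args, x, y, z)

def program_id(kernel_args, axis):
    # Mirrors GetProgramIdOpConversion: the last three (i32) arguments hold
    # the x/y/z program ids, so axis 0/1/2 maps to args[-3], args[-2], args[-1].
    return kernel_args[len(kernel_args) - 3 + axis]
```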
-struct ExtractMemRefOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ExtractMemRefOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); - auto memRefTy = cast(op.getType()); - auto rank = memRefTy.getRank(); - auto memRefStructTy = getTypeConverter()->convertType(op.getType()); - auto memRefStructFields = - cast(memRefStructTy).getBody(); - auto i64Ty = IntegerType::get(getContext(), 64); - - auto copyValue = [&](Value to, int64_t idxFrom, int64_t idxTo) { - auto valueTy = memRefStructFields[idxTo]; - Value val = rewriter.create( - loc, valueTy, tensorPtrStruct, idxFrom); - return rewriter.create(loc, memRefStructTy, to, val, - idxTo); - }; - - Value res = undef(memRefStructTy); - // Copy base. - res = copyValue(res, 0, 1); - // Use 0 offset. - res = rewriter.create(loc, memRefStructTy, res, - i64_val(0), 2); - // Copy shape. - res = copyValue(res, 2, 3); - // Copy strides. - res = copyValue(res, 3, 4); - - rewriter.replaceOp(op, res); - - return success(); - } -}; - -struct ExtractIndicesOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ExtractIndicesOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - - auto loc = op.getLoc(); - Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); - auto rank = op.getNumResults(); - auto i64Ty = IntegerType::get(getContext(), 64); - SmallVector indices; - - for (int64_t i = 0; i < rank; i++) { - Value offs = rewriter.create( - loc, i64Ty, tensorPtrStruct, SmallVector{1, i}); - Value stride = rewriter.create( - loc, i64Ty, tensorPtrStruct, SmallVector{3, i}); - indices.push_back(rewriter.create(loc, offs, stride)); - } - - rewriter.replaceOp(op, indices); - - return success(); - } -}; - -struct MakeTensorPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto structTy = getTypeConverter()->convertType(op.getType()); - auto i64Ty = IntegerType::get(getContext(), 64); - - auto insertArray = [&](Value structVal, auto values, int64_t idx, - Type zextTo = nullptr) { - for (int64_t i = 0; i < static_cast(values.size()); ++i) { - Value val = values[i]; - if (zextTo) - val = rewriter.create(loc, zextTo, val); - structVal = rewriter.create( - loc, structTy, structVal, val, SmallVector{idx, i}); - } - return structVal; - }; - - Value res = undef(structTy); - // 0 - base pointer. - auto base = rewriter.getRemappedValue(op.getBase()); - res = rewriter.create(loc, structTy, res, base, 0); - // 1 - array for offsets. Promote values to i64. - res = insertArray(res, op.getOffsets(), 1, i64Ty); - // 2 - array for shape. - res = insertArray(res, op.getShape(), 2); - // 3 - array for strides. 
- res = insertArray(res, op.getStrides(), 3); - - rewriter.replaceOp(op, res); - - return success(); - } -}; - -struct AdvanceOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto i64Ty = IntegerType::get(getContext(), 64); - Value res = rewriter.getRemappedValue(op.getPtr()); - Type structTy = res.getType(); - auto offsets = op.getOffsets(); - - for (int64_t i = 0; i < offsets.size(); ++i) { - auto oldOffset = rewriter.create( - loc, i64Ty, res, SmallVector{1, i}); - auto step = rewriter.create(loc, i64Ty, offsets[i]); - auto newOffset = rewriter.create(loc, oldOffset, step); - res = rewriter.create(loc, structTy, res, newOffset, - SmallVector{1, i}); - } - - rewriter.replaceOp(op, res); - - return success(); - } -}; - -struct LoadOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Type ptrTy = LLVM::LLVMPointerType::get(getContext()); - Value ptr = rewriter.getRemappedValue(op.getPtr()); - Type resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, ptr, 0, - op.getIsVolatile()); - return success(); - } -}; - -struct StoreOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value ptr = rewriter.getRemappedValue(op.getPtr()); - Value val = rewriter.getRemappedValue(op.getValue()); - rewriter.replaceOpWithNewOp(op, val, ptr); - return success(); - } -}; - -struct PtrToIntOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = rewriter.getRemappedValue(op.getSrc()); - Type resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, src); - return success(); - } -}; - -struct IntToPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value src = rewriter.getRemappedValue(op.getSrc()); - Type resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, src); - return success(); - } -}; - -struct MemoryOpToLLVM - : public triton::impl::MemoryOpToLLVMBase { - using MemoryOpToLLVMBase::MemoryOpToLLVMBase; - - MemoryOpToLLVM() : MemoryOpToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - 
patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createMemoryOpToLLVMPass() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp deleted file mode 100644 index 914f56e668f8..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Conversion/Passes.h" -#include "mlir/Pass/PassManager.h" - -namespace mlir { -namespace triton { -namespace cpu { - -void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { - pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); - pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); - pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); - // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); -} - -void registerTritonCPUToLLVMPipeline() { - PassPipelineRegistration<>("triton-cpu-to-llvmir", - "TritonCPU to LLVM conversion pipeline.", - tritonCPUToLLVMPipelineBuilder); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp deleted file mode 100644 index 144cb57b1115..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "TypeConverter.h" - -using namespace mlir; -using namespace mlir::triton; - -TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( - MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis) - : LLVMTypeConverter(ctx, option, analysis) { - addConversion([&](triton::PointerType type) -> std::optional { - return convertTritonPointerType(type); - }); - addConversion([this](RankedTensorType tensorTy) -> std::optional { - if (isa(tensorTy.getElementType())) - return VectorType::get(tensorTy.getShape(), - IntegerType::get(tensorTy.getContext(), 64)); - return std::nullopt; - }); -} - -Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( - triton::PointerType type) { - auto ctx = type.getContext(); - auto pointeeType = type.getPointeeType(); - if (isa(pointeeType)) { - // struct { - // ptr base_ptr; - // array offsets; - // array shape; - // array strides; - // } - auto tensorTy = cast(pointeeType); - auto rank = tensorTy.getShape().size(); - auto i64Ty = IntegerType::get(ctx, 64); - SmallVector types; - types.push_back(LLVM::LLVMPointerType::get(ctx)); - types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); - types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); - types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); - return LLVM::LLVMStructType::getLiteral(ctx, types); - } - return LLVM::LLVMPointerType::get(ctx); -} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h deleted file mode 100644 index 35d74a9ec430..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H -#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" 
-#include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/Triton/IR/Types.h" - -using namespace mlir; -using namespace mlir::triton; - -class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { -public: - using TypeConverter::convertType; - - TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis = nullptr); - - Type convertTritonPointerType(triton::PointerType type); -}; - -#endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt deleted file mode 100644 index 9fa892b449ac..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_triton_library(TritonToTritonCPU - ConvertDotOp.cpp - ConvertElementwiseOps.cpp - ConvertMemoryOps.cpp - ConvertPtrOps.cpp - Pipeline.cpp - TypeConverter.cpp - - DEPENDS - TritonToTritonCPUPassIncGen - - LINK_LIBS PUBLIC - TritonCPUIR - MLIRVectorDialect -) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp deleted file mode 100644 index b6fbb1893202..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTDOTOP -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class PtrConversionTarget : public ConversionTarget { -public: - explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - addIllegalOp(); - } -}; - -struct DotOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = op.getContext(); - Value a = rewriter.getRemappedValue(op.getA()); - Value b = rewriter.getRemappedValue(op.getB()); - Value c = rewriter.getRemappedValue(op.getC()); - auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); - auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); - auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); - auto iteratorTypes = rewriter.getArrayAttr( - {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)}); - rewriter.replaceOpWithNewOp( - op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), - iteratorTypes); - return success(); - } -}; - 
-struct ConvertDotOp : public triton::impl::ConvertDotOpBase { - using ConvertDotOpBase::ConvertDotOpBase; - - ConvertDotOp() : ConvertDotOpBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - TritonToTritonCPUTypeConverter typeConverter; - PtrConversionTarget convTarget(*context, typeConverter); - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertDotOp() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp deleted file mode 100644 index 70e8c4ed3c66..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ /dev/null @@ -1,300 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTELEMENTWISEOPS -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -class ElementwiseOpConversionTarget : public ConversionTarget { -public: - explicit ElementwiseOpConversionTarget(MLIRContext &ctx, - TypeConverter &converter) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - addDynamicallyLegalDialect( - [&](Operation *op) -> std::optional { - return converter.isLegal(op); - }); - addDynamicallyLegalDialect( - [&](Operation *op) -> std::optional { - return converter.isLegal(op); - }); - - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - } -}; - -template -struct ElementwiseOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - using OpConversionPattern::getTypeConverter; - using typename OpConversionPattern::OpAdaptor; - - LogicalResult - matchAndRewrite(OpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - OperationState newState(op.getLoc(), ResOpT::getOperationName()); - // Convert operands. - for (auto operand : op->getOperands()) { - Value newOperand = rewriter.getRemappedValue(operand); - newState.operands.push_back(newOperand); - } - // Convert result types. 
- if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), - newState.types))) { - return failure(); - } - newState.attributes = op->getAttrs(); - - auto newOp = rewriter.create(newState); - rewriter.replaceOp(op, newOp); - - return success(); - } -}; - -template <> -struct ElementwiseOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(isa(op.getType())); - auto resTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); - assert(resTy); - if (auto denseAttr = dyn_cast(op.getValueAttr())) { - rewriter.replaceOpWithNewOp(op, resTy, - denseAttr.reshape(resTy)); - } else { - llvm_unreachable("Unexpected constant attribute"); - } - return success(); - } -}; - -template <> -struct ElementwiseOpConversion - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(isa(op.getType())); - auto loc = op.getLoc(); - auto src = rewriter.getRemappedValue(op.getSrc()); - auto srcShape = dyn_cast(src.getType()).getShape(); - auto resTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); - auto dstShape = resTy.getShape(); - auto elemTy = resTy.getElementType(); - - // There are restrictions on how shape can be modified by ShapeCastOp - // when rank is changed. For now, we simply detect it and handle through - // a cast to 1D vector. Better solution may be required later. - if (canCastShape(srcShape, dstShape)) { - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), src); - } else { - SmallVector tmpShape({resTy.getNumElements()}); - auto tmp = rewriter.create( - loc, VectorType::get(tmpShape, elemTy), src); - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), tmp); - } - return success(); - } - -private: - bool canCastShape(ArrayRef src, ArrayRef dst) const { - if (src.size() == dst.size()) - return true; - if (src.size() > dst.size()) - return canCastShape(dst, src); - - size_t srcIdx = 0; - size_t dstIdx = 0; - while (srcIdx < src.size() && dstIdx < dst.size()) { - if (src[srcIdx] == 1) { - ++srcIdx; - } else { - // Source dim size should be a product of continuous dest dim sizes. - int64_t srcSize = src[srcIdx++]; - int64_t dstSize = dst[dstIdx++]; - while (dstSize < srcSize && dstIdx < dst.size()) - dstSize *= dst[dstIdx++]; - if (dstSize != srcSize) - return false; - } - } - - // Skip trailing 1s. 
- while (srcIdx < src.size() && src[srcIdx] == 1) - ++srcIdx; - while (dstIdx < dst.size() && dst[dstIdx] == 1) - ++dstIdx; - - return srcIdx == src.size() && dstIdx == dst.size(); - } -}; - -struct ConvertElementwiseOps - : public triton::impl::ConvertElementwiseOpsBase { - using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; - - ConvertElementwiseOps() : ConvertElementwiseOpsBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - TritonToTritonCPUTypeConverter typeConverter; - ElementwiseOpConversionTarget convTarget(*context, typeConverter); - RewritePatternSet patterns(context); - - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, - context); - - patterns.add>(typeConverter, - context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - patterns.add>(typeConverter, context); - - patterns.add>( - typeConverter, context); - patterns - .add>( - typeConverter, context); - patterns.add< - ElementwiseOpConversion>( - typeConverter, context); - patterns.add>( - typeConverter, context); - patterns.add>( - typeConverter, context); - patterns.add>(typeConverter, - context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertElementwiseOps() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp deleted file mode 100644 index 1679ecc7af90..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ /dev/null @@ -1,277 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include 
"mlir/Analysis/DataFlowFramework.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTMEMORYOPS -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -struct LoadOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = loadOp.getLoc(); - auto mask = loadOp.getMask(); - auto ptr = loadOp.getPtr(); - auto boundaryChecks = loadOp.getBoundaryCheck(); - - if (!triton::isTensorPointerType(ptr.getType())) { - return lowerToScalarLoads(loadOp, rewriter); - } - - // TODO: support masks. - if (mask) { - llvm_unreachable("unsupported load op"); - } - - auto memRef = rewriter.getRemappedValue(ptr); - auto rank = dyn_cast(memRef.getType()).getRank(); - auto resTy = dyn_cast( - getTypeConverter()->convertType(loadOp.getResult().getType())); - auto indices = rewriter.create(loc, ptr).getResults(); - SmallVector inBounds(rank, true); - for (auto dim : boundaryChecks) { - inBounds[dim] = false; - } - auto vecRead = rewriter.create(loc, resTy, memRef, - indices, inBounds); - rewriter.replaceOp(loadOp, vecRead); - return success(); - } - - LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, - ConversionPatternRewriter &rewriter) const { - // Scalar loads and boundary checks are not expected. - assert(loadOp.getBoundaryCheck().empty()); - assert(isa(loadOp.getType())); - - auto loc = loadOp.getLoc(); - auto vecTy = - dyn_cast(getTypeConverter()->convertType(loadOp.getType())); - auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); - auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) - : nullptr; - auto ptrTy = - dyn_cast(loadOp.getPtr().getType()).getElementType(); - auto cache = loadOp.getCache(); - auto evict = loadOp.getEvict(); - auto isVolatile = loadOp.getIsVolatile(); - - Value defaultVal = loadOp.getOther(); - if (!defaultVal) - defaultVal = rewriter.create( - loc, rewriter.getZeroAttr(vecTy.getElementType())); - Value dst = rewriter.create(loc, vecTy, defaultVal); - - int64_t numElems = vecTy.getNumElements(); - auto strides = computeStrides(vecTy.getShape()); - for (auto idx = 0; idx < numElems; ++idx) { - auto indices = delinearize(idx, strides); - Block *headerBlock = rewriter.getBlock(); - Block *condBlock = nullptr; - Value origDst = dst; - // Create a conditional block for load if there is a mask. 
- if (mask) { - condBlock = - rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(condBlock); - } - - Value ptr = rewriter.create(loc, ptrs, indices); - ptr = rewriter.create(loc, ptrTy, ptr); - Value val = - rewriter.create(loc, ptr, cache, evict, isVolatile); - dst = rewriter.create(loc, val, dst, indices); - - // Add predicate and branches. - if (mask) { - Block *footerBlock = - rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); - Value resDst = dst; - dst = footerBlock->addArgument(dst.getType(), dst.getLoc()); - rewriter.setInsertionPointToEnd(headerBlock); - auto predicate = rewriter.create(loc, mask, indices); - rewriter.create(loc, predicate, condBlock, - footerBlock, origDst); - rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, footerBlock, resDst); - rewriter.setInsertionPointToStart(footerBlock); - } - } - - rewriter.replaceOp(loadOp, dst); - - return success(); - } -}; - -struct StoreOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = storeOp.getLoc(); - auto mask = storeOp.getMask(); - auto ptr = storeOp.getPtr(); - auto boundaryChecks = storeOp.getBoundaryCheck(); - - if (!triton::isTensorPointerType(ptr.getType())) { - return lowerToScalarStores(storeOp, rewriter); - } - - // TODO: support masks. - if (mask) { - llvm_unreachable("unsupported store op"); - } - - auto value = rewriter.getRemappedValue(storeOp.getValue()); - auto memRef = rewriter.getRemappedValue(ptr); - auto rank = dyn_cast(memRef.getType()).getRank(); - auto indices = rewriter.create(loc, ptr).getResults(); - SmallVector inBounds(rank, true); - for (auto dim : boundaryChecks) { - inBounds[dim] = false; - } - auto vecWrite = rewriter.create(loc, value, memRef, - indices, inBounds); - rewriter.replaceOp(storeOp, vecWrite); - return success(); - } - - LogicalResult lowerToScalarStores(triton::StoreOp storeOp, - ConversionPatternRewriter &rewriter) const { - // Scalar stores and boundary checks are not expected. - assert(storeOp.getBoundaryCheck().empty()); - assert(isa(storeOp.getValue().getType())); - - auto loc = storeOp.getLoc(); - auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); - auto mask = storeOp.getMask() ? rewriter.getRemappedValue(storeOp.getMask()) - : nullptr; - auto vals = rewriter.getRemappedValue(storeOp.getValue()); - auto tensorTy = dyn_cast(storeOp.getPtr().getType()); - auto ptrTy = tensorTy.getElementType(); - auto cache = storeOp.getCache(); - auto evict = storeOp.getEvict(); - - int64_t numElems = tensorTy.getNumElements(); - auto strides = computeStrides(tensorTy.getShape()); - for (auto idx = 0; idx < numElems; ++idx) { - auto indices = delinearize(idx, strides); - Block *headerBlock = rewriter.getBlock(); - Block *condBlock = nullptr; - // Create a conditional block for store if there is a mask. - if (mask) { - condBlock = - rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(condBlock); - } - - Value ptr = rewriter.create(loc, ptrs, indices); - ptr = rewriter.create(loc, ptrTy, ptr); - Value val = rewriter.create(loc, vals, indices); - rewriter.create(loc, ptr, val, cache, evict); - - // Add predicate and branches. 
- if (mask) { - Block *footerBlock = - rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToEnd(headerBlock); - auto predicate = rewriter.create(loc, mask, indices); - rewriter.create(loc, predicate, condBlock, - footerBlock); - rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, footerBlock); - rewriter.setInsertionPointToStart(footerBlock); - } - } - - rewriter.eraseOp(storeOp); - - return success(); - } -}; - -class MemoryOpConversionTarget : public ConversionTarget { -public: - explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - // Allow only scalar loads and stores. - addDynamicallyLegalOp([](triton::LoadOp loadOp) { - return loadOp.getType().isIntOrIndexOrFloat(); - }); - addDynamicallyLegalOp([](triton::StoreOp storeOp) { - return storeOp.getValue().getType().isIntOrIndexOrFloat(); - }); - } -}; - -struct ConvertMemoryOps - : public triton::impl::ConvertMemoryOpsBase { - using ConvertMemoryOpsBase::ConvertMemoryOpsBase; - - ConvertMemoryOps() : ConvertMemoryOpsBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - MemoryOpConversionTarget convTarget(*context); - TritonToTritonCPUTypeConverter pointerConverter; - RewritePatternSet patterns(context); - patterns.add(pointerConverter, context); - patterns.add(pointerConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertMemoryOps() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp deleted file mode 100644 index ade8b858bbfb..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include "TypeConverter.h" - -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTPTROPS -#include "cpu/include/TritonToTritonCPU/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -namespace { - -unsigned getElemBitWidth(Type type) { - if (auto tensorTy = dyn_cast(type)) - return tensorTy.getElementType().getIntOrFloatBitWidth(); - if (auto vectorTy = dyn_cast(type)) - return vectorTy.getElementType().getIntOrFloatBitWidth(); - return type.getIntOrFloatBitWidth(); -} - -class PtrConversionTarget : public ConversionTarget { -public: - explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) - : 
ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - - // Allow only scalar pointer conversion. - addDynamicallyLegalOp( - [](triton::PtrToIntOp op) { return op.getType().isInteger(); }); - addDynamicallyLegalOp([](triton::IntToPtrOp op) { - return op.getSrc().getType().isInteger(); - }); - } -}; - -struct MakeRangeOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - int32_t start = static_cast(op.getStart()); - int32_t end = static_cast(op.getEnd()); - assert(end >= start); - - llvm::SmallVector values; - values.reserve(end - start); - for (int32_t v = start; v < end; ++v) { - values.push_back(v); - } - - Type resTy = getTypeConverter()->convertType(op.getType()); - auto newOp = rewriter.create( - op.getLoc(), resTy, rewriter.getI32VectorAttr(values)); - - rewriter.replaceOp(op, newOp); - return success(); - } -}; - -struct SplatOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value val = op.getSrc(); - Type dstValType = getTypeConverter()->convertType(val.getType()); - // Cast pointer - if (isa(val.getType())) - val = rewriter - .create( - loc, getTypeConverter()->convertType(val.getType()), val) - .getResult(); - Type resType = getTypeConverter()->convertType(op.getType()); - auto cast = rewriter.create(loc, resType, val); - - rewriter.replaceOp(op, cast); - return success(); - } -}; - -struct AddPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - Value ptr = rewriter.getRemappedValue(op.getPtr()); - Value offset = rewriter.getRemappedValue(op.getOffset()); - unsigned offsetBitWidth = getElemBitWidth(offset.getType()); - unsigned elemBitWidth = getPointeeBitWidth(op.getPtr().getType()); - // Compute scale. i1 elements take 1 byte. 
- Value scale = rewriter.create( - loc, (elemBitWidth + 7) / 8, offsetBitWidth); - if (isa(offset.getType())) - scale = rewriter.create(loc, offset.getType(), scale); - offset = rewriter.create(loc, offset, scale); - offset = rewriter.create(loc, ptr.getType(), offset); - rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); - return success(); - } -}; - -struct PtrToIntOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value val = rewriter.getRemappedValue(op.getSrc()); - auto resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, val); - return success(); - } -}; - -struct IntToPtrOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value val = rewriter.getRemappedValue(op.getSrc()); - auto resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, val); - return success(); - } -}; - -struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { - using ConvertPtrOpsBase::ConvertPtrOpsBase; - - ConvertPtrOps() : ConvertPtrOpsBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - - TritonToTritonCPUTypeConverter typeConverter; - PtrConversionTarget convTarget(*context, typeConverter); - RewritePatternSet patterns(context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // namespace - -namespace mlir { -namespace triton { -namespace cpu { - -std::unique_ptr> createConvertPtrOps() { - return std::make_unique(); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp deleted file mode 100644 index 16bff114ed81..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Conversion/Passes.h" -#include "mlir/Pass/PassManager.h" - -namespace mlir { -namespace triton { -namespace cpu { - -void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); - pm.addPass(mlir::triton::cpu::createConvertPtrOps()); - pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); - pm.addPass(mlir::triton::cpu::createConvertDotOp()); - // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); -} - -void registerTritonToTritonCPUPipeline() { - PassPipelineRegistration<>("triton-to-triton-cpu", - "Triton to TritonCPU conversion pipeline.", - tritonToTritonCPUPipelineBuilder); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp deleted file mode 100644 index 07b2da0468ba..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "TypeConverter.h" - 
-#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { - addConversion([](Type type) { return type; }); - addConversion([](triton::PointerType ptrTy) -> Type { - if (triton::isTensorPointerType(ptrTy)) { - // Tensor pointer is translated into a memref - auto tensorTy = dyn_cast(ptrTy.getPointeeType()); - auto elemTy = tensorTy.getElementType(); - // TODO: use dynamic strides - SmallVector shape(tensorTy.getRank(), ShapedType::kDynamic); - return MemRefType::get(shape, elemTy); - } - return IntegerType::get(ptrTy.getContext(), 64); - }); - addConversion([this](RankedTensorType tensorTy) -> Type { - Type elemTy = convertType(tensorTy.getElementType()); - return VectorType::get(tensorTy.getShape(), elemTy); - }); - - // Converted ops produce vectors instead of tensors. Provide conversion - // here for users. Also, convert pointers when required. - addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) -> std::optional { - if (isa(type)) - return builder.create(loc, type, inputs); - return builder.create(loc, type, inputs) - .getResult(0); - }); - - // Converted loads and stores consume memrefs instead of pointers, use extract - // op to get them. Also, provide conversion for vector users and pointer - // casts. - addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) -> std::optional { - if (type.isInteger() && isa(inputs.front().getType())) - return builder.create(loc, type, inputs); - if (isa(type)) - return builder.create(loc, type, inputs) - .getResult(0); - if (isa(type)) - return builder.create(loc, type, inputs); - llvm_unreachable("Unexpected target materizalization"); - }); -} diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h deleted file mode 100644 index cb89f0886c60..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H -#define TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H - -#include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/Triton/IR/Types.h" - -using namespace mlir; -using namespace mlir::triton; - -class TritonToTritonCPUTypeConverter : public TypeConverter { -public: - using TypeConverter::convertType; - - TritonToTritonCPUTypeConverter(); - - Type convertTritonPointerType(triton::PointerType type); -}; - -#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index efc949d6f4a1..302951d04d59 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,20 +1,9 @@ -#include "TritonCPUToLLVM/Passes.h" -#include "TritonToTritonCPU/Passes.h" - -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/Passes.h" -#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "triton/Conversion/TritonCPUToLLVM/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" - #include #include #include @@ -25,26 +14,8 @@ namespace py = pybind11; void 
init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; - // m.def("add_to_llvmir", [](mlir::PassManager &pm) { - // pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); - // }); - m.def("add_triton_to_triton_cpu_pipeline", [](mlir::PassManager &pm) { - mlir::triton::cpu::tritonToTritonCPUPipelineBuilder(pm); - }); - m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { - mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); - }); - m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::createConvertVectorToLLVMPass()); - }); - m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); - }); - m.def("add_math_to_libm", [](mlir::PassManager &pm) { - pm.addPass(mlir::createConvertMathToLibmPass()); - }); - m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::createConvertFuncToLLVMPass()); + m.def("add_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); }); } @@ -54,18 +25,8 @@ void init_triton_cpu(py::module &&m) { m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); - - m.def("find_kernel_names", [](mlir::ModuleOp &mod) { - std::vector res; - mod.walk([&](mlir::FunctionOpInterface funcOp) { - if (funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) - res.push_back(funcOp.getName().str()); - }); - return res; - }); } From 31ad8a1a8d591bcd331bed6054bb672b15e6e1f8 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 17 May 2024 16:58:29 -0500 Subject: [PATCH 011/165] Add a workaround for LLVM bug in codegen for bf16 vector cast. (#4) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 7 +++++++ python/triton/_internal_testing.py | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 582856774785..a66737989d07 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -36,6 +36,7 @@ is_hip_mi300, is_xpu, get_arch, + is_cpu, torch_float8_dtypes, torch_dtypes, numpy_random, @@ -1813,6 +1814,12 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + # bf16 vector cast is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/92471 + # TODO: Remove the change after the bug is fixed. + if is_cpu() and dtype_x == 'bfloat16' and size > 128: + size = 128 + torch.manual_seed(0) # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. 
if dtype_x.startswith('bfloat'): diff --git a/python/triton/_internal_testing.py b/python/triton/_internal_testing.py index dbc2d017930b..d337fa20c7c8 100644 --- a/python/triton/_internal_testing.py +++ b/python/triton/_internal_testing.py @@ -81,6 +81,11 @@ def get_arch(): return "" if target is None else str(target.arch) +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None): """ Override `rs` if you're calling this function twice and don't want the same From 9ec6fa81008b4466a8759aba5e647295f5b3a6e0 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 24 May 2024 13:03:59 -0500 Subject: [PATCH 012/165] Prototype of the Triton CPU backend with basic compilation and execution flows (#6) * Support basic lowering through vector dialect in CPU backend. Signed-off-by: Ilya Enkovich * Use axis info in memory op lowering. Signed-off-by: Ilya Enkovich * Mark test_ptx_cast as enabled for CPU. Signed-off-by: Ilya Enkovich * Support umulhi operation. Signed-off-by: Ilya Enkovich * Support tl.clamp, tl.minimum, tl.maximum. Signed-off-by: Ilya Enkovich * Add enable_fp_fusion opt for CPU (only affects ASM dump now). Signed-off-by: Ilya Enkovich * Fix kernel args passing for propagated constants. Signed-off-by: Ilya Enkovich * Add permutations support. Signed-off-by: Ilya Enkovich * Support 2-D transfer_read/transfer_write lowering. Signed-off-by: Ilya Enkovich * Introduce shape info analysis and use it for loads/stores by block pointers. Delay scalar pointers lowering. Signed-off-by: Ilya Enkovich * Support 'other' arg for loads. Signed-off-by: Ilya Enkovich * Support tl.join. Signed-off-by: Ilya Enkovich * Minor renaming. 
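As a rough illustration of the surface area these bullets describe, a kernel along the following lines exercises masked loads with an explicit 'other' value, elementwise arithmetic, and tl.minimum. This is a hypothetical example, not code from this series; the kernel name, the wrapper, and the assumption that this exact kernel compiles on the prototype CPU backend are all illustrative.

import torch
import triton
import triton.language as tl


@triton.jit
def clipped_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, limit, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the inputs.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Masked loads with an 'other' fallback value, one of the features listed above.
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0.0)
    # Elementwise add plus tl.minimum, both lowered through the vector dialect.
    z = tl.minimum(x + y, limit)
    tl.store(out_ptr + offsets, z, mask=mask)


def clipped_add(x: torch.Tensor, y: torch.Tensor, limit: float) -> torch.Tensor:
    # Thin host-side wrapper: allocate the output and launch a 1-D grid.
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    clipped_add_kernel[grid](x, y, out, n_elements, limit, BLOCK_SIZE=1024)
    return out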
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- bin/RegisterTritonDialects.h | 8 + .../Conversion/TritonCPUToLLVM/CMakeLists.txt | 2 +- .../Dialect/TritonCPU/IR/TritonCPUAttrDefs.td | 5 +- .../Dialect/TritonCPU/IR/TritonCPUDialect.td | 3 + .../Dialect/TritonCPU/IR/TritonCPUOps.td | 67 +++ lib/Conversion/CMakeLists.txt | 4 +- lib/Dialect/TritonCPU/IR/Dialect.cpp | 38 +- python/src/llvm.cc | 68 +++ python/src/passes.cc | 5 +- python/test/unit/language/test_core.py | 76 ++- third_party/cpu/CMakeLists.txt | 5 + third_party/cpu/backend/compiler.py | 76 +-- third_party/cpu/backend/driver.cpp | 224 +++++++++ third_party/cpu/backend/driver.py | 327 ++++++++++++- .../cpu/include/Analysis/TensorPtrShapeInfo.h | 107 ++++ third_party/cpu/include/CMakeLists.txt | 2 + .../include/TritonCPUToLLVM/CMakeLists.txt | 3 + .../cpu/include/TritonCPUToLLVM/Passes.h | 36 ++ .../cpu/include/TritonCPUToLLVM/Passes.td | 46 ++ .../include/TritonToTritonCPU/CMakeLists.txt | 3 + .../cpu/include/TritonToTritonCPU/Passes.h | 38 ++ .../cpu/include/TritonToTritonCPU/Passes.td | 77 +++ third_party/cpu/lib/Analysis/CMakeLists.txt | 11 + .../cpu/lib/Analysis/TensorPtrShapeInfo.cpp | 219 +++++++++ third_party/cpu/lib/CMakeLists.txt | 3 + .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 13 + .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 278 +++++++++++ .../TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp | 98 ++++ .../lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 353 ++++++++++++++ .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp | 25 + .../cpu/lib/TritonCPUToLLVM/TypeConverter.cpp | 43 ++ .../cpu/lib/TritonCPUToLLVM/TypeConverter.h | 22 + .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 16 + .../ConvertControlFlowOps.cpp | 121 +++++ .../lib/TritonToTritonCPU/ConvertDotOp.cpp | 102 ++++ .../ConvertElementwiseOps.cpp | 341 +++++++++++++ .../TritonToTritonCPU/ConvertMemoryOps.cpp | 456 ++++++++++++++++++ .../lib/TritonToTritonCPU/ConvertPtrOps.cpp | 195 ++++++++ .../lib/TritonToTritonCPU/OpTypeConversion.h | 37 ++ .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 27 ++ .../lib/TritonToTritonCPU/TypeConverter.cpp | 34 ++ .../cpu/lib/TritonToTritonCPU/TypeConverter.h | 19 + third_party/cpu/triton_cpu.cc | 56 ++- 43 files changed, 3617 insertions(+), 72 deletions(-) create mode 100644 third_party/cpu/backend/driver.cpp create mode 100644 third_party/cpu/include/Analysis/TensorPtrShapeInfo.h create mode 100644 third_party/cpu/include/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.h create mode 100644 third_party/cpu/include/TritonCPUToLLVM/Passes.td create mode 100644 third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.h create mode 100644 third_party/cpu/include/TritonToTritonCPU/Passes.td create mode 100644 third_party/cpu/lib/Analysis/CMakeLists.txt create mode 100644 third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp create mode 100644 third_party/cpu/lib/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp create mode 100644 
third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h create mode 100644 third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h create mode 100644 third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 17737e1096c6..ca922e824793 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -17,6 +17,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "cpu/include/TritonToTritonCPU/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" @@ -69,6 +71,12 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints(); mlir::triton::registerTritonAMDGPULowerInstructionSchedHints(); + // CPU passes + mlir::triton::cpu::registerTritonToTritonCPUPasses(); + mlir::triton::cpu::registerTritonToTritonCPUPipeline(); + mlir::triton::cpu::registerTritonCPUToLLVMPasses(); + mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); + // TODO: register Triton & TritonGPU passes registry.insert traits = [], string baseCppClass = "::mlir::Attribute"> : AttrDef { - let description = [{ - WIP... - }]; + let description = [{TritonCPU attr.}]; + let attrName = "triton.cpu." 
# attrMnemonic; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td index 9ccac13f0b58..260db2743046 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUDialect.td @@ -17,6 +17,7 @@ def TritonCPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "tensor::TensorDialect", + "mlir::memref::MemRefDialect", ]; let extraClassDeclaration = [{ @@ -24,6 +25,8 @@ def TritonCPU_Dialect : Dialect { }]; let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; + let usePropertiesForAttributes = 1; } #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 16d9e433e899..712826d02f91 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -7,6 +7,73 @@ include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" +include "triton/Dialect/Triton/IR/TritonInterfaces.td" include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/ViewLikeInterface.td" + +class TTC_Op traits = []> : + Op { +} + +def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> { + let summary = "Extract base memref from a block pointer"; + + let description = [{ + Extract base memref from a block pointer. It covers whole base tensor memory, + not only the block referenced. Base pointer, shape, and strides are used + in the resulting memref. Offsets and block shape are ignored. + + }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs AnyRankedOrUnrankedMemRef:$result); + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_ExtractIndicesOp : TTC_Op<"extract_indices", [NoMemoryEffect]> { + let summary = "Extract indices from a block pointer."; + + let description = [{ + Extract indices that can be used to access the block using its base memref. + Indices are supposed to be used for vector loads/stores with the base + memref extracted from the same block pointer. + }]; + + let arguments = (ins TT_TensorPtr:$src); + + let results = (outs Variadic:$result); + + let builders = [ + OpBuilder<(ins "Value":$src)> + ]; + + let hasCanonicalizer = 1; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_PtrToMemRefOp : TTC_Op<"ptr_to_memref", [NoMemoryEffect]> { + let summary = "Build a memref for a pointer."; + + let description = [{ + Build memref with static shape, offset, strides, and specified base pointer. 
+ }]; + + let arguments = (ins TT_Ptr:$src); + + let results = (outs AnyStaticShapeMemRef:$result); + + let hasCanonicalizer = 0; + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} #endif diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 5c3aa2c1a827..83db4ae41607 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,4 @@ -add_subdirectory(TritonToTritonCPU) +#add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -add_subdirectory(TritonCPUToLLVM) +#add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index e28a65358dca..acd31c07290f 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -2,16 +2,20 @@ #include +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/Transforms/DialectConversion.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/TypeSwitch.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.cpp.inc" + using namespace mlir; +using namespace mlir::triton; using namespace mlir::triton::cpu; //===----------------------------------------------------------------------===// @@ -20,6 +24,35 @@ using namespace mlir::triton::cpu; #define GET_ATTRDEF_CLASSES #include "triton/Dialect/TritonCPU/IR/TritonCPUAttrDefs.cpp.inc" +void ExtractMemRefOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +void ExtractIndicesOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) {} + +/// Parse an attribute registered to this dialect. +::mlir::Attribute +TritonCPUDialect::parseAttribute(::mlir::DialectAsmParser &parser, + ::mlir::Type type) const { + llvm_unreachable("parse stub called"); +} + +/// Print an attribute registered to this dialect. 
+void TritonCPUDialect::printAttribute(::mlir::Attribute attr, + ::mlir::DialectAsmPrinter &os) const { + llvm_unreachable("print stub called"); +} + +void ExtractIndicesOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, Value src) { + assert(triton::isTensorPointerType(src.getType()) && + "Unexecpeted source type"); + auto tensorTy = dyn_cast( + dyn_cast(src.getType()).getPointeeType()); + SmallVector resTypes(tensorTy.getRank(), builder.getIndexType()); + build(builder, state, resTypes, src); +} + void TritonCPUDialect::initialize() { registerTypes(); @@ -34,6 +67,9 @@ void TritonCPUDialect::initialize() { >(); } +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" + // verify TritonCPU ops LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/python/src/llvm.cc b/python/src/llvm.cc index c86bf671a7df..1f7f9b03d676 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -3,6 +3,8 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/CodeGen/CommandFlags.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -21,6 +23,7 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" +#include "llvm/TargetParser/Host.h" #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/InstCombine/InstCombine.h" #include "llvm/Transforms/Instrumentation/AddressSanitizer.h" @@ -403,6 +406,71 @@ void init_triton_llvm(py::module &&m) { py::arg("flags") = std::vector{}, py::arg("enable_fp_fusion") = false); + m.def("set_host_target", [](llvm::Module *mod) { + mod->setTargetTriple(llvm::sys::getDefaultTargetTriple()); + std::string error; + auto target = + llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); + std::unique_ptr machine{target->createTargetMachine( + mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, + llvm::Reloc::PIC_)}; + mod->setDataLayout(machine->createDataLayout()); + }); + + m.def( + "translate_to_host_asm", + [](std::string llvmIR, bool enable_fp_fusion) -> py::object { + std::string res; + { + // when allow_threads goes out of scope, gil will be released + py::gil_scoped_release allow_threads; + // create LLVM module from C++ + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + res = + translateLLVMIRToASM(*module, llvm::sys::getDefaultTargetTriple(), + llvm::sys::getHostCPUName().str(), "", {}, + enable_fp_fusion, false); + } + return py::str(res); + }, + ret::take_ownership); + + m.def( + "translate_to_bc", + [](const std::string llvmIR) -> py::object { + py::gil_scoped_release allow_threads; + // create LLVM module + llvm::LLVMContext context; + std::unique_ptr buffer = + llvm::MemoryBuffer::getMemBuffer(llvmIR.c_str()); + llvm::SMDiagnostic error; + std::unique_ptr module = + llvm::parseIR(buffer->getMemBufferRef(), error, context); + if (!module) { + llvm::report_fatal_error( + "failed to parse IR: " + error.getMessage() + + "lineno: " + std::to_string(error.getLineNo())); + } + // Write bitcode 
to a buffer. + llvm::SmallVector buf; + llvm::BitcodeWriter writer(buf); + writer.writeModule(*module); + writer.writeStrtab(); + std::string bitcode(buf.begin(), buf.end()); + return py::bytes(bitcode); + }, + ret::take_ownership); + m.def( "translate_to_asm", [](std::string llvmIR, std::string triple, std::string proc, diff --git a/python/src/passes.cc b/python/src/passes.cc index c365aaf43589..9e34f6ad7fed 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -45,8 +45,8 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); - ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", - createConvertTritonToTritonCPUPass); + // ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", + // createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { @@ -89,6 +89,7 @@ void init_triton_passes_convert(py::module &&m) { ADD_PASS_WRAPPER_0("add_cf_to_llvmir", createConvertControlFlowToLLVMPass); ADD_PASS_WRAPPER_0("add_index_to_llvmir", createConvertIndexToLLVMPass); ADD_PASS_WRAPPER_0("add_arith_to_llvmir", createArithToLLVMConversionPass); + ADD_PASS_WRAPPER_0("add_math_to_llvmir", createConvertMathToLLVMPass); } void init_triton_passes_llvmir(py::module &&m) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a66737989d07..697df74ab7be 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -305,6 +305,7 @@ def filter_layouts(layouts): return [l for l in layouts if is_layout_applicable(l)] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) def test_empty_kernel(dtype_x, device): @@ -544,6 +545,7 @@ def test_dtype_codegen(): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -604,6 +606,7 @@ def promote_to_fp32(dtype_x, dtype_y): test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) def test_addptr(dtype, order, device): @@ -630,6 +633,7 @@ def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): np.testing.assert_allclose(y, to_numpy(y_tri)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y", [ # (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes @@ -650,6 +654,7 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, filter_y=filter_y, device=device, num_ctas=num_ctas) +@pytest.mark.cpu def test_unsigned_name_mangling(device): # Test that uint32 and int32 are mangled differently by the compiler SIZE = 128 @@ -686,6 +691,7 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # test bitwise ops # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) @@ -710,6 +716,7 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes @@ -732,6 +739,7 @@ def 
test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "dtype_x, dtype_y, op, mode_x, mode_y", @@ -756,6 +764,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- # test broadcast # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def test_broadcast(dtype, device): @@ -790,6 +799,7 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con # ---------- +@pytest.mark.cpu @pytest.mark.interpreter def test_slice(device): @@ -821,6 +831,7 @@ def slice_kernel(XBLOCK: tl.constexpr): # ------------------ +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_slice(device): dst = torch.empty(128, device=device) @@ -836,6 +847,7 @@ def _kernel(dst): # ---------------- # test expand_dims # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims(device): @@ -884,6 +896,7 @@ def expand_dims_kernel(dummy, N: tl.constexpr): expand_dims_kernel[(1, )](dummy_tensor, N) +@pytest.mark.cpu @pytest.mark.interpreter def test_expand_dims_error_cases(device): @@ -947,6 +960,7 @@ def duplicate_dim2(dummy, N: tl.constexpr): # ---------------------------- # test invalid program id axis # ---------------------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_invalid_pid_axis(device): dst = torch.empty(128, device=device) @@ -963,6 +977,7 @@ def _kernel(dst): # --------------- # test where # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1015,6 +1030,7 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl. 
assert (z == to_numpy(z_tri)).all() +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): @@ -1059,6 +1075,7 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr", [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x') @@ -1073,6 +1090,7 @@ def test_unary_op(dtype_x, expr, num_ctas, device): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) @@ -1083,6 +1101,7 @@ def test_math_op(dtype_x, expr, x, device): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_erf_op(dtype, device): @@ -1104,6 +1123,7 @@ def kernel(Z, X, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [dtype for dtype in ["float32", "float64"]]) def test_math_fma_op(dtype, device): @@ -1129,6 +1149,7 @@ def kernel(Z, X, Y, W, SIZE: tl.constexpr): torch.testing.assert_close(z_tri, z_ref) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr", ["tl.math.fdiv(x, y)", "tl.math.div_rn(x, y)"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1141,6 +1162,7 @@ def test_math_divide_op(expr, num_ctas, device): # ------------- # test precise math # ------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("expr_prec, expr_ref", [('tl.math.sqrt_rn(x)', 'tl.math.sqrt(x.to(tl.float64)).to(tl.float32)'), @@ -1181,6 +1203,7 @@ def kernel(X, Y, OUT, OUT_REF, BLOCK: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): @@ -1226,6 +1249,7 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_shapes_as_params(device): @@ -1260,6 +1284,7 @@ def kernel(): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_transpose(dtype_x, device): @@ -1368,6 +1393,7 @@ def tuples_fn(a, b): a * b +@pytest.mark.cpu @pytest.mark.interpreter def test_tuples(device): @@ -1460,6 +1486,7 @@ def noinline_multi_values_fn(x, y, Z): tl.store(Z, z) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) def test_noinline(mode, device): @@ -1782,6 +1809,7 @@ def kernel_r(ptrs, BLOCK_SIZE: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ @@ -1916,6 +1944,7 @@ def kernel(X, Y, Z, N: tl.constexpr): assert z.unique().size(0) == z.size(0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", list(torch_dtypes)) @pytest.mark.parametrize("constant_field", ["value", "mask"]) @@ -1947,6 +1976,7 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, CONSTANT_FIELD: tl. 
assert torch.all(output == 0) +@pytest.mark.cpu def test_load_store_same_ptr(device): @triton.jit() @@ -1965,6 +1995,7 @@ def kernel(in_out_ptr): assert torch.all(x == 2) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ['int32']) def test_umulhi(dtype_str, device): @@ -2002,6 +2033,7 @@ def umulhi32(a, b): np.testing.assert_equal(z_ref, to_numpy(z_tri)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join(device): @@ -2022,6 +2054,7 @@ def kernel(X, Y, Z, N: tl.constexpr): np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join_scalars(device): @@ -2041,6 +2074,7 @@ def kernel(X, Y, Z): np.testing.assert_equal([42, 100], to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_join_with_mma(device): @@ -2079,6 +2113,7 @@ def kernel(Z, N: tl.constexpr): np.testing.assert_equal(to_numpy(z_ref), to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_interleave_scalars(device): @@ -2175,6 +2210,7 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): return output +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("in_dtype", [torch.float16, torch.bfloat16]) def test_convert_float16_to_float32(in_dtype, device): @@ -3229,6 +3265,7 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) # TODO: bfloat16 @@ -3288,6 +3325,7 @@ def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constex assert 'st.global.v4' in ptx +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["int32", "int8"]) @pytest.mark.parametrize("shape", [(2, 4), (16, 16)]) @@ -3311,6 +3349,7 @@ def kernel(In, Out, in_shape1: tl.constexpr, in_shape2: tl.constexpr, ou_shape1: np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ["int32", "int8"]) @pytest.mark.parametrize("shape", [(2, 2, 8, 64), (4, 4, 4, 4)]) @@ -4123,6 +4162,7 @@ def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.co assert re.search(r"ttg.async_wait %.* {num = 2 : i32}", h.asm["ttgir"]) is not None +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) @pytest.mark.parametrize("shape", [(), (1, ), (128, )]) @@ -4162,6 +4202,7 @@ def kernel_dynamic(out, val, dtype: tl.constexpr): assert torch.all(out_dynamic == 2) +@pytest.mark.cpu @pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), ('float("-inf")', "f32"), ('float("nan")', "f32"), ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) @@ -4186,6 +4227,7 @@ def pass_const(a, b, choose_b): return a +@pytest.mark.cpu @pytest.mark.parametrize("choose_const", [True, False]) @pytest.mark.parametrize("constexpr", [True, False]) @pytest.mark.parametrize("mode", ["direct", "call", "ternary", "if"]) @@ -4279,6 +4321,7 @@ def _kernel(out): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("start", [0, 1, 7, 16]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -4302,6 +4345,7 @@ def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, size, size_diff, other", [(dtype_str, 
size, size_diff, other) for dtype_str in torch_dtypes @@ -4340,6 +4384,7 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): torch.testing.assert_close(output, reference_out) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) @pytest.mark.parametrize("mask_val", [True, False]) @@ -4639,6 +4684,7 @@ def _impl(value=10): return value +@pytest.mark.cpu @pytest.mark.interpreter def test_default(device): value = 5 @@ -4664,6 +4710,7 @@ def _kernel(ret0, ret1, value=3): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_noop(device): @@ -4691,6 +4738,7 @@ def kernel(x): kernel[(1, )](x) +@pytest.mark.cpu @pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) @@ -4716,6 +4764,7 @@ def kernel(value1, is_one, X): # -------------------- +@pytest.mark.cpu @pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: @@ -4737,6 +4786,7 @@ def kernel(VALUE, X): # ---------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @@ -4774,6 +4824,7 @@ def kernel(Z, X, Y): np.testing.assert_allclose(z, to_numpy(z_tri), rtol=1e-3) +@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_shape(device): @@ -4787,6 +4838,7 @@ def kernel(X): np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) +@pytest.mark.cpu @pytest.mark.interpreter def test_constexpr_scalar_shape(device): @@ -4804,6 +4856,7 @@ def kernel(X, s): reshape_list = [((64, ), (8, 8)), ((2, 32), (16, 4)), ((512, ), (2, 2, 2, 2, 2, 2, 2, 2, 2)), ((64, 32), (16, 8, 16))] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("formats", reshape_list) def test_reshape(formats, device): @@ -4831,6 +4884,7 @@ def generate_kernel(shape_x, shape_z): np.testing.assert_equal(z, to_numpy(z_tri)) +@pytest.mark.cpu def test_reshape_err(device): @triton.jit @@ -4876,6 +4930,7 @@ def kernel(ptr): assert "int16 tensor descriptor block shape must have at least 16 columns" in str(e.value.__cause__) +@pytest.mark.cpu def test_trans_reshape(device): @triton.jit @@ -4902,8 +4957,9 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con actual = torch.zeros(expected.shape, dtype=torch.int32, device=device) k = kernel[(1, )](input, actual, shape[0], shape[1]) - assert k.asm['ttgir'].count( - 'ttg.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" + if not is_cpu(): + assert k.asm['ttgir'].count( + 'ttg.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization" np.testing.assert_equal(to_numpy(expected), to_numpy(actual)) @@ -4937,6 +4993,7 @@ def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): tl.store(ptr + offsets, vec, mask=mask) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("type", ["inline", "noinline"]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -4968,6 +5025,7 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): # ------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("if_type", [ "if", 
"if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", @@ -5028,6 +5086,7 @@ def _kernel(dst): _kernel[(1, )](dst=dst, num_warps=4) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("func_str", ['sqrt', 'rsqrt', 'exp', 'exp2', 'log', 'log2', 'sin', 'cos']) def test_unary_math(func_str, device): @@ -5261,6 +5320,7 @@ def kernel(A, B, C, D, BLOCK: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) def test_for_iv(lo, hi, iv, device): @@ -5280,6 +5340,7 @@ def kernel(Out, lo, hi, iv: tl.constexpr): assert out[0] == sum(range(lo, hi, iv)) +@pytest.mark.cpu @pytest.mark.interpreter def test_if_else(device): @@ -5305,6 +5366,7 @@ def kernel(Cond, TrueVal, FalseVal, Out): assert to_numpy(out)[0] == false_val[0] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("mode", ["dynamic", "static"]) def test_if_return(mode, device): @@ -5364,6 +5426,7 @@ def add_fn_static_cond(x, cond: tl.constexpr): return x + 1 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "call_type", @@ -5433,6 +5496,7 @@ def kernel(Out, call_type: tl.constexpr): assert to_numpy(out)[0] == 1 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("_cond1", [True, False]) @pytest.mark.parametrize("_cond2", [True, False]) @@ -5475,6 +5539,7 @@ def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): assert out[0] == targets[(_cond1, _cond2, _cond3)] +@pytest.mark.cpu @pytest.mark.interpreter def test_while(device): @@ -5503,6 +5568,7 @@ def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): assert out_j[0] == bound[0] +@pytest.mark.cpu @pytest.mark.interpreter def test_nested_while(device): @@ -6224,6 +6290,7 @@ def test_convert_warp_local(M, N, src_layout, dst_layout, dtype, device, tmp_pat torch.testing.assert_close(z, x, rtol=0, atol=0) +@pytest.mark.cpu @pytest.mark.interpreter def test_load_scalar_with_mask(device): @@ -6242,6 +6309,7 @@ def kernel(Input, Index, Out, N: int): # This test is used to test our own PTX codegen for float16 and int16 conversions # maybe delete it later after ptxas has been fixed +@pytest.mark.cpu @pytest.mark.parametrize("dtype_str", ['float16', 'int16']) def test_ptx_cast(dtype_str, device): @@ -6444,6 +6512,7 @@ def simple(data, out): # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("dtype", ['float16', 'float32']) @pytest.mark.parametrize("propagate_nan", ['NONE', 'ALL']) @pytest.mark.parametrize("func", ['minimum', 'maximum', 'clamp']) @@ -6482,6 +6551,7 @@ def kernel(A, B, C, propagate_nan: tl.constexpr, func: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", ['float16', 'float32']) def test_clamp(dtype, device): @@ -6518,6 +6588,7 @@ def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexp # Test for symmetric clamp(x, -limit, limit), as it may go through optimized # codegen in the backends +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", ['bfloat16', 'float16', 'float32']) def test_clamp_symmetric(dtype, device): @@ -6553,6 +6624,7 @@ def kernel(x_ptr, limit_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexpr): # ----------------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_static_range(device): diff --git 
a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 683889547b0a..1b08addbc9b7 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,3 +1,8 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) +add_subdirectory(include) +add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 3c293cdf468f..344cdd2f05ae 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -4,7 +4,7 @@ import re from dataclasses import dataclass -from typing import Any +from typing import Any, Tuple from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget @@ -20,6 +20,9 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee",) + allow_fp8e4nv: bool = False + enable_fp_fusion: bool = True # TODO: We may introduce CPU-specific options like # of cores. @@ -40,7 +43,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "exe" + self.binary_ext = "bc" def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -62,7 +65,6 @@ def make_ttir(mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() passes.common.add_inliner(pm) - passes.ttir.add_rewrite_tensor_pointer(pm) passes.ttir.add_combine(pm) passes.common.add_canonicalizer(pm) passes.ttir.add_reorder_broadcast(pm) @@ -77,33 +79,36 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.ttir.add_convert_to_ttcpuir(pm) - - # - # TODO: - # - + cpu.passes.ttcpuir.add_triton_to_triton_cpu_pipeline(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) pm.run(mod) + metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) return mod @staticmethod def make_llir(src, metadata, options): + # warp-specialization mutates num_warps + num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") + if num_warp_groups is not None: + metadata["num_warps"] *= num_warp_groups + metadata["threads_per_warp"] = 1 mod = src # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) + cpu.passes.ttcpuir.add_lower_affine(pm) passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) - - cpu.passes.ttcpuir.add_to_llvmir(pm) - passes.common.add_canonicalizer(pm) - passes.common.add_cse(pm) - - passes.convert.add_scf_to_cf(pm) - passes.convert.add_cf_to_llvmir(pm) + cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) + passes.convert.add_math_to_llvmir(pm) + cpu.passes.ttcpuir.add_math_to_libm(pm) + cpu.passes.ttcpuir.add_vector_to_llvmir(pm) + cpu.passes.ttcpuir.add_memref_to_llvmir(pm) passes.convert.add_arith_to_llvmir(pm) + cpu.passes.ttcpuir.add_func_to_llvmir(pm) passes.common.add_canonicalizer(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) @@ -111,45 +116,40 @@ def make_llir(src, metadata, 
options): passes.llvmir.add_di_scope(pm) pm.run(mod) + # Find kernel fn + kernel_names = cpu.find_kernel_names(mod) + assert len(kernel_names) == 1, f"expected exactly 1 kernel in a module, got {kernel_names}" + # LLVM-IR (MLIR) -> LLVM-IR (LLVM) llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) - - # TODO: - if not llvm_mod: - metadata["shared"] = 0 - return src - - if options.extern_libs: - paths = [path for (name, path) in options.extern_libs] - llvm.link_extern_libs(llvm_mod, paths) + llvm.set_host_target(llvm_mod) + #if options.extern_libs: + # paths = [path for (name, path) in options.extern_libs] + # llvm.link_extern_libs(llvm_mod, paths) llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3) - - # CPU doesn't have SMEM, but just to make it work for now. + # Get some metadata metadata["shared"] = 0 - - # Cleanup + metadata["name"] = kernel_names[0] ret = str(llvm_mod) del llvm_mod del context return ret @staticmethod - def make_exe(src, metadata, options): - # Just a quick hack while developing the backend. - names = re.findall(r"\s+define void @([a-zA-Z_][a-zA-Z0-9_]*)\(", str(src)) - assert len(names) == 1 - metadata["name"] = names[0] - - # TODO: Call llc to create an executable. - return src + def make_bc(src, metadata, options): + if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": + print("********** Module ASM **********") + print(llvm.translate_to_host_asm(src, options.enable_fp_fusion)) + ret = llvm.translate_to_bc(src) + return ret def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["exe"] = lambda src, metadata: self.make_exe(src, metadata, options) + stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.cpp b/third_party/cpu/backend/driver.cpp new file mode 100644 index 000000000000..babff3dfdebe --- /dev/null +++ b/third_party/cpu/backend/driver.cpp @@ -0,0 +1,224 @@ +//===- driver.cpp ---------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/TargetSelect.h" + +#include +#include +#include +#include +#include +#include +#include + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + return Py_BuildValue("{s:i}", "max_shared_mem", 0); +} + +bool getBoolEnv(const std::string &env) { + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return std::tolower(c); }); + return (str == "on" || str == "true" || str == "1"); +} + +llvm::orc::ThreadSafeContext &getThreadSafeContext() { + static llvm::orc::ThreadSafeContext tsc; + static std::once_flag init_flag; + std::call_once(init_flag, []() { + auto context = std::make_unique(); + tsc = llvm::orc::ThreadSafeContext(std::move(context)); + }); + return tsc; +} + +std::string llvmErrToString(const llvm::Error &err) { + std::string res; + llvm::raw_string_ostream os(res); + os << err; + return res; +}; + +struct CompiledKernel { + std::unique_ptr execution_session; + std::unique_ptr data_layout; + std::unique_ptr mangle; + std::unique_ptr object_layer; + std::unique_ptr compiler_layer; + llvm::orc::JITDylib *dylib = nullptr; + + CompiledKernel() = default; + CompiledKernel(CompiledKernel &&) = default; + + ~CompiledKernel() { + if (execution_session) + llvm::cantFail(execution_session->endSession()); + } +}; + +std::vector> compiled_kernels; + +static PyObject *loadBitcode(PyObject *self, PyObject *args) { + const char *name; + int shared; + PyObject *py_bytes; + int devId; + + if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) { + std::cerr << "loadBitcode arg parse failed" << std::endl; + return NULL; + } + + std::string kernel_name = name; + size_t binary_size = PyBytes_Size(py_bytes); + const char *binary_ptr = PyBytes_AsString(py_bytes); + + llvm::LLVMContext context; + auto buf = llvm::MemoryBuffer::getMemBuffer( + llvm::StringRef(binary_ptr, binary_size)); + auto mod = llvm::parseBitcodeFile(*buf, context); + if (!mod) { + std::cerr << "Failed to parse LLVM bitcode module" << std::endl; + return NULL; + } + + if (getBoolEnv("MLIR_ENABLE_DUMP")) { + llvm::errs() << "********** Loaded Module (kernel_name=" << name + << ") **********\n" + << **mod << "\n"; + } + + auto init_err = llvm::InitializeNativeTarget(); + if (init_err) { + std::cerr << "Failed to initialize native target." 
<< std::endl; + return NULL; + } + + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + auto self_epc = + llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); + + auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!detect_host_res) { + std::cerr << "Failed to initialize JITTargetMachineBuilder: " + << llvmErrToString(detect_host_res.takeError()); + return NULL; + } + llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); + + auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); + if (!data_layout_res) { + std::cerr << "Failed to initialize data layout: " + << llvmErrToString(data_layout_res.takeError()); + return NULL; + } + + CompiledKernel kernel; + kernel.execution_session = + std::make_unique(std::move(self_epc)); + kernel.data_layout = + std::make_unique(std::move(*data_layout_res)); + kernel.mangle = std::make_unique( + *kernel.execution_session, *kernel.data_layout); + kernel.object_layer = std::make_unique( + *kernel.execution_session, + []() { return std::make_unique(); }); + kernel.compiler_layer = std::make_unique( + *kernel.execution_session, *kernel.object_layer, + std::make_unique(std::move(tmb))); + + auto dylib_res = kernel.execution_session->createJITDylib("
"); + if (!dylib_res) { + std::cerr << "Failed to create initialize JITDylib: " + << llvmErrToString(dylib_res.takeError()); + return NULL; + } + + kernel.dylib = &(*dylib_res); + kernel.dylib->addGenerator(llvm::cantFail( + llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( + kernel.data_layout->getGlobalPrefix()))); + + // Compile module. + (**mod).setDataLayout(*kernel.data_layout); + llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); + auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); + if (err) { + std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); + return NULL; + } + + // Find kernel function pointer. + auto lookup_res = + kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); + if (!lookup_res) { + std::cerr << "Failed to find function " << std::string(name) + << "\nError: " << llvmErrToString(lookup_res.takeError()); + return NULL; + } + uint64_t fn_ptr = lookup_res->getAddress().getValue(); + + compiled_kernels.push_back( + std::make_unique(std::move(kernel))); + auto *kernel_ptr = compiled_kernels.back().get(); + + return Py_BuildValue("(KKii)", reinterpret_cast(kernel_ptr), + reinterpret_cast(fn_ptr), 0, 0); +} + +static PyObject *initContext(PyObject *self, PyObject *args) { + return Py_BuildValue("(K)", (uint64_t)0); +} + +static PyObject *initDevices(PyObject *self, PyObject *args) { + return Py_BuildValue("(i)", 1); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBitcode, METH_VARARGS, + "Load provided SPV into ZE driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cpu_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cpu_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 3f3816a99b9f..3fe243fc262d 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,5 +1,100 @@ +import os +import hashlib +import tempfile +from pathlib import Path +from triton.runtime.build import _build +from triton.runtime.cache import get_cache_manager +from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget -from triton.backends.driver import CPUDriverBase + +dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") +llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") +llvm_root = os.path.expanduser(llvm_root) +llvm_dirs = os.listdir(llvm_root) +if len(llvm_dirs) == 1: + llvm_root = os.path.join(llvm_root, llvm_dirs[0]) +include_dir = [ + os.path.join(dirname, "include"), + os.path.join(llvm_root, "include"), +] +library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] +libraries = [ + "LLVMOrcJIT", + "LLVMPasses", + "LLVMX86CodeGen", + "LLVMX86AsmParser", + "LLVMX86Desc", + "LLVMX86Info", + "LLVMGlobalISel", + "LLVMSelectionDAG", + "LLVMHipStdPar", + "LLVMCoroutines", + "LLVMipo", + "LLVMFrontendOpenMP", + "LLVMInstrumentation", + "LLVMAsmPrinter", + "LLVMCodeGen", + "LLVMObjCARCOpts", + "LLVMLinker", + "LLVMVectorize", + "LLVMScalarOpts", + "LLVMInstCombine", + "LLVMFrontendOffloading", + "LLVMExecutionEngine", + "LLVMAggressiveInstCombine", + "LLVMTransformUtils", + "LLVMTarget", + 
"LLVMRuntimeDyld", + "LLVMJITLink", + "LLVMIRPrinter", + "LLVMBitWriter", + "LLVMAnalysis", + "LLVMProfileData", + "LLVMSymbolize", + "LLVMDebugInfoDWARF", + "LLVMObject", + "LLVMTextAPI", + "LLVMMCParser", + "LLVMMCDisassembler", + "LLVMMC", + "LLVMIRReader", + "LLVMCFGuard", + "LLVMBitReader", + "LLVMAsmParser", + "LLVMCore", + "LLVMBinaryFormat", + "LLVMOrcTargetProcess", + "LLVMTargetParser", + "LLVMRemarks", + "LLVMOrcShared", + "LLVMOption", + "LLVMDebugInfoCodeView", + "LLVMCodeGenTypes", + "LLVMBitstreamReader", + "LLVMSupport", + "LLVMDemangle", + "stdc++", +] + + +def compile_module_from_src(src, name): + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.cpp") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location(name, cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + # ------------------------ # Utils @@ -15,22 +110,12 @@ def __new__(cls): def __init__(self): pass + dirname = os.path.dirname(os.path.realpath(__file__)) + mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") + self.load_binary = mod.load_binary - @staticmethod - def get_device_properties(device): - # This is just dummy for now. We will need to implement driver.c. - return { - "max_shared_mem": 0, - "multiprocessor_count": 0, - "sm_clock_rate": 0, - "mem_clock_rate": 0, - "mem_bus_width": 0, - } - - @staticmethod - def load_binary(name, kernel_asm, shared, device): - # This is just dummy for now. We will need to implement driver.c. - return (None, kernel_asm, 0, 0) + def get_device_properties(self, *args): + return {"max_shared_mem": 0} # ------------------------ @@ -38,27 +123,229 @@ def load_binary(name, kernel_asm, shared, device): # ------------------------ +def ty_to_cpp(ty): + if ty[0] == '*': + return "void*" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u1": "uint32_t", + "u8": "uint8_t", + "u16": "uint16_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": "float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + def make_launcher(constants, signature, ids): - pass + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors. 
+ arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return ty_to_cpp(ty) + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "int8_t": "b", + "int16_t": "h", + "int32_t": "i", + "int64_t": "l", + "uint8_t": "B", + "uint16_t": "H", + "uint32_t": "I", + "uint64_t": "K", + }[ty] + + args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + format = "iiiOKOOOO" + args_format + arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + kernel_fn_args = [i for i in signature.keys() if i not in constants] + kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else '' + kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" + + # generate glue code + src = f""" +#include +#include +#include +#include + +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include +#include + +using kernel_ptr_t = void(*)({kernel_fn_arg_types}); + +typedef struct _DevicePtrInfo {{ + void* dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + ptr_info.dev_ptr = (void*) PyLong_AsLongLong(ret); + if(!ptr_info.dev_ptr) {{ + return ptr_info; + }} + Py_DECREF(ret); // Thanks ChatGPT! 
+ return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + ptr_info.valid = false; + return ptr_info; +}} + +static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // TODO: add OMP pragmas to run in parallel + for (uint32_t z = 0; z < gridZ; ++z) {{ + for (uint32_t y = 0; y < gridY; ++y) {{ + for (uint32_t x = 0; x < gridX; ++x) {{ + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + }} + }} + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + + + int gridX, gridY, gridZ; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *kernel_metadata = NULL; + PyObject *launch_metadata = NULL; + PyObject *py_obj_stream; + void* pKrnl; + + {' '.join([f"{_extracted_type(ty)} arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &py_obj_stream, &pKrnl, + &kernel_metadata, &launch_metadata, + &launch_enter_hook, &launch_exit_hook {', ' + arg_ptrs_list if len(signature) > 0 else ''})) {{ + return NULL; + }} + + void *pStream = PyLong_AsVoidPtr(py_obj_stream); + kernel_ptr_t kernel_ptr = reinterpret_cast(pKrnl); + + // extract launch metadata + if (launch_enter_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_enter_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + run_omp_kernels(gridX, gridY, gridZ, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + + if(launch_exit_hook != Py_None){{ + PyObject* args = Py_BuildValue("(O)", launch_metadata); + PyObject* ret = PyObject_CallObject(launch_exit_hook, args); + Py_DECREF(args); + if (!ret) + return NULL; + }} + + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_cpu_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_cpu_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src class CPULauncher(object): def __init__(self, src, metadata): - # TODO: - self.launch = lambda *args, **kwargs: None + ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()} + constants = src.constants if hasattr(src, "constants") else dict() + cst_key = lambda i: src.fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + signature = {cst_key(key): value for key, value in src.signature.items()} + src = make_launcher(constants, signature, ids) + mod = compile_module_from_src(src, "__triton_cpu_launcher") + self.launch = mod.launch def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) -class CPUDriver(CPUDriverBase): +class 
CPUDriver(DriverBase): def __init__(self): self.utils = CPUUtils() self.launcher_cls = CPULauncher super().__init__() + def get_current_device(self): + return 0 + + def get_current_stream(self, device): + return 0 + def get_current_target(self): # Capability and warp size are zeros for CPU. # TODO: GPUTarget naming isn't obviously good. diff --git a/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h b/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h new file mode 100644 index 000000000000..838ecebb6add --- /dev/null +++ b/third_party/cpu/include/Analysis/TensorPtrShapeInfo.h @@ -0,0 +1,107 @@ +#ifndef TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H +#define TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "llvm/Support/raw_ostream.h" + +#include "mlir/Support/LLVM.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#include +#include + +namespace mlir::triton::cpu { + +// Lattice value to hold a shape and strides for a tensor pointer. +// If multiple size or stride values are possible for some dimension +// then ShapedType::kDynamic is used for that dimension. +class TensorPtrShapeInfo { +public: + TensorPtrShapeInfo() = default; + + TensorPtrShapeInfo(ArrayRef shape, ArrayRef strides) + : shape(shape), strides(strides) { + assert(shape.size() == strides.size()); + } + + ArrayRef getShape() const { return shape; } + ArrayRef getStrides() const { return strides; } + + int64_t getRank() const { return static_cast(shape.size()); } + int64_t getSize(int64_t dim) const { return shape[dim]; } + int64_t getStride(int64_t dim) const { return strides[dim]; } + + bool operator==(const TensorPtrShapeInfo &other) const { + return shape == other.shape && strides == other.strides; + } + + static TensorPtrShapeInfo join(const TensorPtrShapeInfo &lhs, + const TensorPtrShapeInfo &rhs); + + static TensorPtrShapeInfo getPessimisticValueState(Value value); + + void print(raw_ostream &os) const { + os << "shape = ["; + llvm::interleaveComma(shape, os); + os << "], strides = ["; + llvm::interleaveComma(strides, os); + os << "]"; + } + +private: + SmallVector shape; + SmallVector strides; +}; + +using TensorPtrShapeInfoMapT = DenseMap; +class ModuleTensorPtrShapeInfoAnalysis + : public CallGraph { +public: + explicit ModuleTensorPtrShapeInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, TensorPtrShapeInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = dyn_cast(callOp.resolveCallable()); + update(callOp, callee); + }); + } + } + + TensorPtrShapeInfo *getPtrShapeInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } + +private: + void initialize(FunctionOpInterface funcOp); + void update(CallOpInterface callOp, FunctionOpInterface funcOp); +}; + +} // 
namespace mlir::triton::cpu + +#endif // TRITON_CPU_ANALYSIS_TENSORPTRSHAPEINFO_H diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt new file mode 100644 index 000000000000..fc9a19e52b0d --- /dev/null +++ b/third_party/cpu/include/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..64b36523d35d --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) +add_public_tablegen_target(TritonCPUConversionPassIncGen) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h new file mode 100644 index 000000000000..74f74b00870c --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -0,0 +1,36 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H +#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +std::unique_ptr> createFuncOpToLLVMPass(); +std::unique_ptr> createMemoryOpToLLVMPass(); +std::unique_ptr> createGetProgramIdOpToLLVMPass(); + +void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); +void registerTritonCPUToLLVMPipeline(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td new file mode 100644 index 000000000000..c75b58b572f1 --- /dev/null +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -0,0 +1,46 @@ +#ifndef TRITONCPU_CONVERSION_PASSES +#define TRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert FuncOp to LLVM for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createFuncOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton memory operations to LLVM for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createMemoryOpToLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::scf::SCFDialect", + "mlir::memref::MemRefDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton GetProgramId to LLVM for CPU."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createGetProgramIdOpToLLVMPass()"; + + let dependentDialects = ["mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect"]; +} + +#endif diff --git 
a/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..56e231273ed6 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) +add_public_tablegen_target(TritonToTritonCPUPassIncGen) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h new file mode 100644 index 000000000000..745799039691 --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -0,0 +1,38 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES_H +#define TRITONTOTRITONCPU_CONVERSION_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +std::unique_ptr> createConvertElementwiseOps(); +std::unique_ptr> createConvertMemoryOps(); +std::unique_ptr> createConvertPtrOps(); +std::unique_ptr> createConvertDotOp(); +std::unique_ptr> createConvertControlFlowOps(); + +void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); +void registerTritonToTritonCPUPipeline(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td new file mode 100644 index 000000000000..5f52f3a2e31d --- /dev/null +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -0,0 +1,77 @@ +#ifndef TRITONTOTRITONCPU_CONVERSION_PASSES +#define TRITONTOTRITONCPU_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton memory ops."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::ModuleOp"> { + let summary = "Convert elementwise ops."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertElementwiseOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton ops related to pointer arithmetics."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertPtrOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotOp : Pass<"triton-cpu-convert-dot-op", "mlir::ModuleOp"> { + let summary = "Convert Triton DotOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertDotOp()"; + + let 
dependentDialects = ["mlir::arith::ArithDialect",
+                          "mlir::memref::MemRefDialect",
+                          "mlir::vector::VectorDialect",
+                          "mlir::triton::TritonDialect",
+                          "mlir::triton::cpu::TritonCPUDialect"];
+}
+
+def ConvertControlFlowOps : Pass<"triton-cpu-convert-control-flow-op", "mlir::ModuleOp"> {
+  let summary = "Convert Triton control flow ops.";
+  let description = [{
+
+  }];
+  let constructor = "mlir::triton::cpu::createConvertControlFlowOps()";
+
+  let dependentDialects = ["mlir::arith::ArithDialect",
+                           "mlir::memref::MemRefDialect",
+                           "mlir::vector::VectorDialect",
+                           "mlir::scf::SCFDialect",
+                           "mlir::triton::TritonDialect",
+                           "mlir::triton::cpu::TritonCPUDialect"];
+}
+
+#endif
diff --git a/third_party/cpu/lib/Analysis/CMakeLists.txt b/third_party/cpu/lib/Analysis/CMakeLists.txt
new file mode 100644
index 000000000000..d0ac08b9daf0
--- /dev/null
+++ b/third_party/cpu/lib/Analysis/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_triton_library(TritonCPUAnalysis
+  TensorPtrShapeInfo.cpp
+
+  DEPENDS
+  TritonCPUTableGen
+
+  LINK_LIBS PUBLIC
+  MLIRAnalysis
+  TritonIR
+  TritonCPUIR
+)
diff --git a/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp b/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp
new file mode 100644
index 000000000000..bd3959e051f0
--- /dev/null
+++ b/third_party/cpu/lib/Analysis/TensorPtrShapeInfo.cpp
@@ -0,0 +1,219 @@
+#include "cpu/include/Analysis/TensorPtrShapeInfo.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+
+namespace mlir::triton::cpu {
+
+TensorPtrShapeInfo TensorPtrShapeInfo::join(const TensorPtrShapeInfo &lhs,
+                                            const TensorPtrShapeInfo &rhs) {
+  // If one argument is not initialized, return the other.
+  if (lhs.getRank() == 0)
+    return rhs;
+  if (rhs.getRank() == 0)
+    return lhs;
+  assert(lhs.getRank() == rhs.getRank());
+
+  SmallVector<int64_t> shape(lhs.getShape());
+  SmallVector<int64_t> strides(lhs.getStrides());
+  for (int64_t i = 0; i < lhs.getRank(); ++i) {
+    if (shape[i] != rhs.getSize(i))
+      shape[i] = ShapedType::kDynamic;
+    if (strides[i] != rhs.getStride(i))
+      strides[i] = ShapedType::kDynamic;
+  }
+  return TensorPtrShapeInfo(shape, strides);
+}
+
+namespace {
+
+template <class T>
+void initPessimisticStateFromFunc(int argNumber, T funcOp,
+                                  SmallVectorImpl<int64_t> &shape,
+                                  SmallVectorImpl<int64_t> &strides) {
+  auto loadFromAttr = [&](std::string_view attrName,
+                          SmallVectorImpl<int64_t> &out) {
+    Attribute attr = funcOp.getArgAttr(argNumber, attrName);
+    if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
+      auto vals = dense_attr.getValues<int64_t>();
+      out = SmallVector<int64_t>(vals.begin(), vals.end());
+    }
+  };
+  loadFromAttr("tt.shape", shape);
+  loadFromAttr("tt.strides", strides);
+}
+
+TensorPtrShapeInfo getPessimisticValueState(Value value) {
+  int rank = 0;
+  if (triton::isTensorPointerType(value.getType()))
+    rank = cast<RankedTensorType>(getPointeeType(value.getType())).getRank();
+
+  SmallVector<int64_t> shape;
+  SmallVector<int64_t> strides;
+
+  BlockArgument blockArg = dyn_cast<BlockArgument>(value);
+
+  if (blockArg && blockArg.getOwner()->isEntryBlock()) {
+    Operation *op = blockArg.getOwner()->getParentOp();
+    if (auto fun = dyn_cast<triton::FuncOp>(op))
+      initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, shape,
+                                   strides);
+    // llvm codegen checks alignment to generate vector load/store
+    // would be nice if this wasn't the case
+    else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
+      initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, shape,
+                                   strides);
+  } else if (Operation *op = value.getDefiningOp()) {
+    if (isa<scf::ForOp, scf::IfOp, scf::WhileOp>(op)) {
+      // scf::ForOp, scf::IfOp, scf::WhileOp
+      // Control flow operations are initialized with "unknown" state.
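+      // For example, the result of an scf.if that yields a tensor pointer
+      // starts with empty shape and strides (rank 0); join() treats that as
+      // "not yet initialized" and simply adopts whatever is propagated later
+      // from the yielded values.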
+ } else { + // Other operations are conservatively initialized with dynamic + // shape and strides unless they have specified. + if (Attribute attr = op->getDiscardableAttr("tt.shape")) { + auto vals = cast(attr).getValues(); + shape = SmallVector(vals.begin(), vals.end()); + } else { + shape.insert(shape.end(), rank, ShapedType::kDynamic); + } + if (Attribute attr = op->getDiscardableAttr("tt.strides")) { + auto vals = cast(attr).getValues(); + strides = SmallVector(vals.begin(), vals.end()); + } else { + strides.insert(strides.end(), rank, ShapedType::kDynamic); + } + } + } + + return TensorPtrShapeInfo(shape, strides); +} + +class ShapeInfoAnalysis : public dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice> { +private: + void + setToEntryState(dataflow::Lattice *lattice) override { + propagateIfChanged( + lattice, lattice->join(getPessimisticValueState(lattice->getAnchor()))); + } + +public: + ShapeInfoAnalysis(DataFlowSolver &solver); + using dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; + using FuncShapeInfoMapT = DenseMap; + + LogicalResult visitOperation( + Operation *op, + ArrayRef *> operands, + ArrayRef *> results) override; +}; + +ShapeInfoAnalysis::ShapeInfoAnalysis(DataFlowSolver &solver) + : dataflow::SparseForwardDataFlowAnalysis< + dataflow::Lattice>(solver) {} + +SmallVector copyConstOrDynamic(OperandRange ops) { + SmallVector res; + for (auto op : ops) { + if (auto cstOp = op.getDefiningOp()) { + auto intAttr = dyn_cast(cstOp.getValue()); + assert(intAttr); + res.push_back(intAttr.getInt()); + } else { + res.push_back(ShapedType::kDynamic); + } + } + return res; +} + +LogicalResult ShapeInfoAnalysis::visitOperation( + Operation *op, + ArrayRef *> operands, + ArrayRef *> results) { + // TODO: For sure not the right way to do this + // but why is scf.if not initialized otherwise? + for (auto op : operands) + if (op->getValue().getRank() == 0) + setToEntryState((dataflow::Lattice *)op); + + TensorPtrShapeInfo res; + // Tensor pointers are only produced by MakeTensorPtrOp which has + // shape and strides as its args, and AdvanceOp which preserves + // shape and strides of the input pointer. 
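+  // Rough illustration (schematic IR, value names invented for the example):
+  //   %p = tt.make_tensor_ptr %base, [%c128, %n], [%c1, %ld], ...
+  //   %q = tt.advance %p, [%i, %j]
+  // gives both %p and %q shape = [128, kDynamic] and strides = [1, kDynamic],
+  // since only operands defined by arith.constant become static entries.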
+ if (auto makePtrOp = dyn_cast(op)) { + SmallVector shape = copyConstOrDynamic(makePtrOp.getShape()); + SmallVector strides = copyConstOrDynamic(makePtrOp.getStrides()); + res = TensorPtrShapeInfo(shape, strides); + } else if (auto advOp = dyn_cast(op)) { + res = operands[0]->getValue(); + } + + // join all lattice elements + for (auto *result : results) + propagateIfChanged(result, result->join(res)); + + return success(); +} + +} // namespace + +void ModuleTensorPtrShapeInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + ShapeInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *shapeInfoMap = getFuncData(funcOp); + auto updateShapeInfoMap = [&](Value value) { + auto shapeInfo = analysis->getLatticeElement(value)->getValue(); + TensorPtrShapeInfo curShapeInfo; + if (shapeInfoMap->count(value)) { + curShapeInfo = + TensorPtrShapeInfo::join(shapeInfo, shapeInfoMap->lookup(value)); + } else { + curShapeInfo = shapeInfo; + } + (*shapeInfoMap)[value] = curShapeInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateShapeInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateShapeInfoMap(value); + } + }); +} + +void ModuleTensorPtrShapeInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *shapeInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, ArrayRef value) { + SmallVector curValue(value); + if (auto attr = + callee.getArgAttrOfType(index, attrName)) { + auto oldValue = cast(attr).getValues(); + assert(oldValue.size() == curValue.size()); + for (size_t i = 0; i < curValue.size(); ++i) + if (curValue[i] != oldValue[i]) + curValue[i] = ShapedType::kDynamic; + } + auto attr = DenseElementsAttr::get( + VectorType::get(curValue.size(), + IntegerType::get(callee.getContext(), 64)), + ArrayRef(curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto shapeInfo = shapeInfoMap->lookup(value); + if (shapeInfo.getRank()) { + setAttrFn("tt.shape", shapeInfo.getShape()); + setAttrFn("tt.strides", shapeInfo.getStrides()); + } + } +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt new file mode 100644 index 000000000000..1db64c58ec20 --- /dev/null +++ b/third_party/cpu/lib/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(Analysis) +add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt new file mode 100644 index 000000000000..884c9352ef1b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -0,0 +1,13 @@ +add_triton_library(TritonCPUToLLVM + FuncOpToLLVM.cpp + GetProgramIdOpToLLVM.cpp + MemoryOpToLLVM.cpp + Pipeline.cpp + TypeConverter.cpp + + DEPENDS + TritonCPUToLLVMConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRVectorToLLVMPass +) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp new file mode 100644 index 000000000000..4c5257fcff4c --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -0,0 +1,278 @@ +#include "TypeConverter.h" + +#include 
"cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_FUNCOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +namespace mlir { +FailureOr +convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, + ConversionPatternRewriter &rewriter, + const LLVMTypeConverter &converter); +} + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +struct FuncOpConversion : public ConvertOpToLLVMPattern { + FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + /// Only retain those attributes that are not constructed by + /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument + /// attributes. + static void filterFuncAttributes(triton::FuncOp op, bool filterArgAttrs, + SmallVectorImpl &result) { + + for (const auto &attr : op->getAttrs()) { + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == op.getFunctionTypeAttrName() || + attr.getName() == "std.varargs" || + (filterArgAttrs && attr.getName() == op.getArgAttrsAttrName())) + continue; + result.push_back(attr); + } + } + + triton::FuncOp amendProgramIdArgs(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back a variable that indicates the current stack pointer of shared + // memory to the function arguments. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + // 1. Modify the function type to add new arguments. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(i32_ty); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add new arguments. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + SmallVector amendedArgAttrs; + if (funcOp.getAllArgAttrs()) + amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. 
Add a new arguments to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto ®ion = funcOp.getBody(); + region.addArgument(i32_ty, loc); + region.addArgument(i32_ty, loc); + region.addArgument(i32_ty, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } + + LogicalResult + matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Prevent LLVM's inliner to inline this function + auto modifiedFuncOp = funcOp; + if (LLVM::isKernel(funcOp)) + modifiedFuncOp = amendProgramIdArgs(modifiedFuncOp, rewriter); + + LLVM::LLVMFuncOp newFuncOp = *mlir::convertFuncOpToLLVMFuncOp( + modifiedFuncOp, rewriter, *getTypeConverter()); + if (!newFuncOp) + return failure(); + + // required by AxisInfoAnalysis + if (LLVM::isKernel(funcOp)) + rewriter.eraseOp(modifiedFuncOp); + rewriter.eraseOp(funcOp); + return success(); + } +}; + +struct ReturnOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. + auto funcOp = op->getParentOfType(); + auto packedResultsTy = this->getTypeConverter()->packFunctionResults( + funcOp.getResultTypes()); + Value packedResults = + rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = + insert_val(packedResultsTy, packedResults, it.value(), it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); + return success(); + } +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. 
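+    // For example, a tt.call returning (i32, f32) becomes an llvm.call that
+    // returns !llvm.struct<(i32, f32)>; getCallOpResults() below then unpacks
+    // it with one llvm.extractvalue per result.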
+ Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), + promotedOperands, callOp->getAttrs()); + return newCallOp; + } + + SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } +}; + +struct FuncOpToLLVM : public triton::impl::FuncOpToLLVMBase { + using FuncOpToLLVMBase::FuncOpToLLVMBase; + + FuncOpToLLVM() : FuncOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + // Lower tt.func + RewritePatternSet funcPatterns(context); + funcPatterns.add(typeConverter, + /*benefit=*/1); + mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, + funcPatterns); + if (failed( + applyPartialConversion(mod, convTarget, std::move(funcPatterns)))) + return signalPassFailure(); + + // Lower tt.call, tt.return + int benefit = 10; + RewritePatternSet patterns(context); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createFuncOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp new file mode 100644 index 000000000000..4c593f1ff7aa --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp @@ -0,0 +1,98 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include 
"triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_GETPROGRAMIDOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. +struct GetProgramIdOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(GetProgramIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); + auto args = funcOp.getArguments(); + // Last three args are x, y, z program ids. + auto argIdx = args.size() - 3 + op.getAxisAsInt(); + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + rewriter.replaceOp(op, args[argIdx]); + return success(); + } +}; + +struct GetProgramIdOpToLLVM + : public triton::impl::GetProgramIdOpToLLVMBase { + using GetProgramIdOpToLLVMBase::GetProgramIdOpToLLVMBase; + + GetProgramIdOpToLLVM() : GetProgramIdOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createGetProgramIdOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp new file mode 100644 index 000000000000..68d7231039c5 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -0,0 +1,353 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_MEMORYOPTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using 
namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// TODO: use enums to access struct fields. +struct ExtractMemRefOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractMemRefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto memRefTy = cast(op.getType()); + auto rank = memRefTy.getRank(); + auto memRefStructTy = getTypeConverter()->convertType(op.getType()); + auto memRefStructFields = + cast(memRefStructTy).getBody(); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto copyValue = [&](Value to, int64_t idxFrom, int64_t idxTo) { + auto valueTy = memRefStructFields[idxTo]; + Value val = rewriter.create( + loc, valueTy, tensorPtrStruct, idxFrom); + return rewriter.create(loc, memRefStructTy, to, val, + idxTo); + }; + + Value res = undef(memRefStructTy); + // Copy base. + res = copyValue(res, 0, 1); + // Use 0 offset. + res = rewriter.create(loc, memRefStructTy, res, + i64_val(0), 2); + // Copy shape. + res = copyValue(res, 2, 3); + // Copy strides. + res = copyValue(res, 3, 4); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct ExtractIndicesOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ExtractIndicesOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = op.getLoc(); + Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); + auto rank = op.getNumResults(); + auto i64Ty = IntegerType::get(getContext(), 64); + SmallVector indices; + + for (int64_t i = 0; i < rank; i++) { + Value offs = rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{1, i}); + Value stride = rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{3, i}); + indices.push_back(rewriter.create(loc, offs, stride)); + } + + rewriter.replaceOp(op, indices); + + return success(); + } +}; + +struct PtrToMemRefOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(PtrToMemRefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getSrc()); + auto memRefStructTy = getTypeConverter()->convertType(op.getType()); + + Value res = undef(memRefStructTy); + res = + rewriter.create(loc, memRefStructTy, res, ptr, 1); + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct MakeTensorPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto structTy = getTypeConverter()->convertType(op.getType()); + auto i64Ty = IntegerType::get(getContext(), 64); + + auto insertArray = [&](Value structVal, auto values, int64_t idx, + Type zextTo = nullptr) { + for (int64_t i = 0; i < static_cast(values.size()); ++i) { + Value val = values[i]; + if (zextTo) + val = rewriter.create(loc, zextTo, val); + structVal = rewriter.create( + loc, structTy, structVal, val, SmallVector{idx, i}); + } + return 
structVal; + }; + + Value res = undef(structTy); + // 0 - base pointer. + auto base = rewriter.getRemappedValue(op.getBase()); + res = rewriter.create(loc, structTy, res, base, 0); + // 1 - array for offsets. Promote values to i64. + res = insertArray(res, op.getOffsets(), 1, i64Ty); + // 2 - array for shape. + res = insertArray(res, op.getShape(), 2); + // 3 - array for strides. + res = insertArray(res, op.getStrides(), 3); + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct AdvanceOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AdvanceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto i64Ty = IntegerType::get(getContext(), 64); + Value res = rewriter.getRemappedValue(op.getPtr()); + Type structTy = res.getType(); + auto offsets = op.getOffsets(); + + for (int64_t i = 0; i < offsets.size(); ++i) { + auto oldOffset = rewriter.create( + loc, i64Ty, res, SmallVector{1, i}); + auto step = rewriter.create(loc, i64Ty, offsets[i]); + auto newOffset = rewriter.create(loc, oldOffset, step); + res = rewriter.create(loc, structTy, res, newOffset, + SmallVector{1, i}); + } + + rewriter.replaceOp(op, res); + + return success(); + } +}; + +struct LoadOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Type ptrTy = LLVM::LLVMPointerType::get(getContext()); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, ptr, 0, + op.getIsVolatile()); + return success(); + } +}; + +struct StoreOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value val = rewriter.getRemappedValue(op.getValue()); + rewriter.replaceOpWithNewOp(op, val, ptr); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = rewriter.getRemappedValue(op.getSrc()); + Type resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } +}; + +struct AddPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Expect only scalar pointers here. 
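+    // Tensor-of-pointer arithmetic was already lowered in TritonToTritonCPU,
+    // so only scalar tt.addptr is expected at this point; it maps to a single
+    // GEP-style address computation over the pointee element type.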
+ assert(isa(op.getType())); + auto ptrTy = cast(op.getPtr().getType()); + Type elemTy = getTypeConverter()->convertType(ptrTy.getPointeeType()); + Type resTy = getTypeConverter()->convertType(ptrTy); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value offset = rewriter.getRemappedValue(op.getOffset()); + rewriter.replaceOpWithNewOp(op, resTy, elemTy, ptr, offset); + return success(); + } +}; + +struct PtrBitcastConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By this moment we expect tt.bitcast used only for scalar pointer casts. + // This cast becomes NOP for LLVM dialect, so simply return the source arg. + assert(isa(op.getType())); + assert(isa(op.getSrc().getType())); + Value src = rewriter.getRemappedValue(op.getSrc()); + rewriter.replaceOp(op, src); + return success(); + } +}; + +struct PtrSelectConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // By this moment we expect tt.bitcast used only for scalar pointer casts. + // This cast becomes NOP for LLVM dialect, so simply return the source arg. + if (!isa(op.getType())) + return failure(); + + Value trueVal = rewriter.getRemappedValue(op.getTrueValue()); + Value falseVal = rewriter.getRemappedValue(op.getFalseValue()); + Value cond = rewriter.getRemappedValue(op.getCondition()); + rewriter.replaceOpWithNewOp(op, cond, trueVal, falseVal); + return success(); + } +}; + +struct MemoryOpToLLVM + : public triton::impl::MemoryOpToLLVMBase { + using MemoryOpToLLVMBase::MemoryOpToLLVMBase; + + MemoryOpToLLVM() : MemoryOpToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createMemoryOpToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp new file mode 100644 index 000000000000..914f56e668f8 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp @@ -0,0 +1,25 @@ +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +namespace triton { +namespace cpu { + +void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { + 
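+  // Note: pass order matters here. Functions are lowered to llvm.func first so
+  // that the later patterns (e.g. the get_program_id lowering) can rely on an
+  // LLVM::FuncOp parent and on the extra program-id arguments it carries.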
pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); + pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); + pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); + // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); +} + +void registerTritonCPUToLLVMPipeline() { + PassPipelineRegistration<>("triton-cpu-to-llvmir", + "TritonCPU to LLVM conversion pipeline.", + tritonCPUToLLVMPipelineBuilder); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp new file mode 100644 index 000000000000..144cb57b1115 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp @@ -0,0 +1,43 @@ +#include "TypeConverter.h" + +using namespace mlir; +using namespace mlir::triton; + +TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( + MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, option, analysis) { + addConversion([&](triton::PointerType type) -> std::optional { + return convertTritonPointerType(type); + }); + addConversion([this](RankedTensorType tensorTy) -> std::optional { + if (isa(tensorTy.getElementType())) + return VectorType::get(tensorTy.getShape(), + IntegerType::get(tensorTy.getContext(), 64)); + return std::nullopt; + }); +} + +Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( + triton::PointerType type) { + auto ctx = type.getContext(); + auto pointeeType = type.getPointeeType(); + if (isa(pointeeType)) { + // struct { + // ptr base_ptr; + // array offsets; + // array shape; + // array strides; + // } + auto tensorTy = cast(pointeeType); + auto rank = tensorTy.getShape().size(); + auto i64Ty = IntegerType::get(ctx, 64); + SmallVector types; + types.push_back(LLVM::LLVMPointerType::get(ctx)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + types.push_back(LLVM::LLVMArrayType::get(ctx, i64Ty, rank)); + return LLVM::LLVMStructType::getLiteral(ctx, types); + } + return LLVM::LLVMPointerType::get(ctx); +} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h new file mode 100644 index 000000000000..35d74a9ec430 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h @@ -0,0 +1,22 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { +public: + using TypeConverter::convertType; + + TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, + const DataLayoutAnalysis *analysis = nullptr); + + Type convertTritonPointerType(triton::PointerType type); +}; + +#endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt new file mode 100644 index 000000000000..997fb748878a --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -0,0 +1,16 @@ +add_triton_library(TritonToTritonCPU + ConvertControlFlowOps.cpp + ConvertDotOp.cpp + ConvertElementwiseOps.cpp + ConvertMemoryOps.cpp + ConvertPtrOps.cpp + Pipeline.cpp + 
TypeConverter.cpp + + DEPENDS + TritonToTritonCPUPassIncGen + + LINK_LIBS PUBLIC + TritonCPUIR + MLIRVectorDialect +) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp new file mode 100644 index 000000000000..9cf6e31810d7 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp @@ -0,0 +1,121 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTCONTROLFLOWOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ControlFlowOpConversionTarget : public ConversionTarget { +public: + explicit ControlFlowOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + } +}; + +struct ForOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value lowerBound = rewriter.getRemappedValue(op.getLowerBound()); + Value upperBound = rewriter.getRemappedValue(op.getUpperBound()); + Value step = rewriter.getRemappedValue(op.getStep()); + SmallVector initArgs; + if (failed(rewriter.getRemappedValues(op.getInitArgs(), initArgs))) + return failure(); + // Create new for op with remapped values. + auto newOp = rewriter.create(op.getLoc(), lowerBound, + upperBound, step, initArgs); + // Move the old op block and convert its sigature. 
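+    // The builder gives the new loop its own entry block; swap in the original
+    // body and let convertRegionTypes rewrite the block argument types
+    // (tensors become vectors) to match the remapped init args.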
+ Block *oldBlock = op.getBody(); + Block *newBlock = newOp.getBody(); + rewriter.moveBlockBefore(oldBlock, newOp.getBody()); + rewriter.eraseBlock(newBlock); + if (failed(rewriter.convertRegionTypes(oldBlock->getParent(), + *getTypeConverter()))) + return failure(); + rewriter.replaceOp(op, newOp); + + return success(); + } +}; + +struct ConvertControlFlowOps + : public triton::impl::ConvertControlFlowOpsBase { + using ConvertControlFlowOpsBase::ConvertControlFlowOpsBase; + + ConvertControlFlowOps() : ConvertControlFlowOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ControlFlowOpConversionTarget convTarget(*context, typeConverter); + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add>(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertControlFlowOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp new file mode 100644 index 000000000000..51a5f42fa63a --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -0,0 +1,102 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTDOTOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class DotConversionTarget : public ConversionTarget { +public: + explicit DotConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct DotOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Value a = rewriter.getRemappedValue(op.getA()); + Value b = rewriter.getRemappedValue(op.getB()); + Value c = 
rewriter.getRemappedValue(op.getC()); + auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } +}; + +struct ConvertDotOp : public triton::impl::ConvertDotOpBase { + using ConvertDotOpBase::ConvertDotOpBase; + + ConvertDotOp() : ConvertDotOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + DotConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp new file mode 100644 index 000000000000..218dd827619a --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -0,0 +1,341 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTELEMENTWISEOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ElementwiseOpConversionTarget : public ConversionTarget { +public: + explicit ElementwiseOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + addDynamicallyLegalDialect( + [&](Operation *op) -> std::optional { + return converter.isLegal(op); + }); + + addDynamicallyLegalOp( + [](triton::BitcastOp op) { return isa(op.getType()); }); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + } +}; + +struct ConstantOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + 
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + assert(resTy); + if (auto denseAttr = dyn_cast(op.getValueAttr())) { + rewriter.replaceOpWithNewOp(op, resTy, + denseAttr.reshape(resTy)); + } else { + llvm_unreachable("Unexpected constant attribute"); + } + return success(); + } +}; + +struct ReshapeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcShape = dyn_cast(src.getType()).getShape(); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto dstShape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + + // There are restrictions on how shape can be modified by ShapeCastOp + // when rank is changed. For now, we simply detect it and handle through + // a cast to 1D vector. Better solution may be required later. + if (canCastShape(srcShape, dstShape)) { + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), src); + } else { + SmallVector tmpShape({resTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, elemTy), src); + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), tmp); + } + return success(); + } + +private: + bool canCastShape(ArrayRef src, ArrayRef dst) const { + if (src.size() == dst.size()) + return true; + if (src.size() > dst.size()) + return canCastShape(dst, src); + + size_t srcIdx = 0; + size_t dstIdx = 0; + while (srcIdx < src.size() && dstIdx < dst.size()) { + if (src[srcIdx] == 1) { + ++srcIdx; + } else { + // Source dim size should be a product of continuous dest dim sizes. + int64_t srcSize = src[srcIdx++]; + int64_t dstSize = dst[dstIdx++]; + while (dstSize < srcSize && dstIdx < dst.size()) + dstSize *= dst[dstIdx++]; + if (dstSize != srcSize) + return false; + } + } + + // Skip trailing 1s. 
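+    // Whatever remains on either side after the walk must be unit dims,
+    // otherwise a single shape cast cannot reconcile the two shapes.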
+ while (srcIdx < src.size() && src[srcIdx] == 1) + ++srcIdx; + while (dstIdx < dst.size() && dst[dstIdx] == 1) + ++dstIdx; + + return srcIdx == src.size() && dstIdx == dst.size(); + } +}; + +struct MulhiUIOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getX()); + auto rhs = rewriter.getRemappedValue(op.getY()); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + auto vecI32Ty = lhsTy.cloneWith(std::nullopt, rewriter.getI32Type()); + auto vecI64Ty = lhsTy.cloneWith(std::nullopt, rewriter.getI64Type()); + assert(lhsTy.getElementType().isInteger()); + assert(rhsTy.getElementType().isInteger()); + // Cast to int64 + if (lhsTy.getElementTypeBitWidth() < 64) { + lhs = rewriter.create(loc, vecI64Ty, lhs); + } + if (rhsTy.getElementTypeBitWidth() < 64) { + rhs = rewriter.create(loc, vecI64Ty, rhs); + } + Value res = rewriter.create(loc, lhs, rhs); + Value cst32 = rewriter.create( + loc, DenseElementsAttr::get(vecI64Ty, 32LL)); + res = rewriter.create(loc, res, cst32); + res = rewriter.create(loc, vecI32Ty, res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ClampFOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getX()); + auto minVal = rewriter.getRemappedValue(op.getMin()); + auto maxVal = rewriter.getRemappedValue(op.getMax()); + Value res; + if (op.getPropagateNanAttr().getValue() == PropagateNan::ALL) { + res = rewriter.create(loc, val, minVal); + res = rewriter.create(loc, res, maxVal); + } else { + res = rewriter.create(loc, val, minVal); + res = rewriter.create(loc, res, maxVal); + } + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct TransOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getSrc()); + auto order = op.getOrder(); + SmallVector permutation(order.begin(), order.end()); + rewriter.replaceOpWithNewOp(op, val, permutation); + return success(); + } +}; + +struct JoinOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto interleave = rewriter.create(loc, lhs, rhs); + // JoinOp creates a new dimension, but InterleaveOp doubles the final one. + // Use ShapeCastOp to get the required shape. 
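+    // E.g. joining two 4x8 operands interleaves them into a 4x16 vector, which
+    // is then shape-cast to the 4x8x2 result type expected for tt.join.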
+ auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, interleave); + return success(); + } +}; + +struct ConvertElementwiseOps + : public triton::impl::ConvertElementwiseOpsBase { + using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; + + ConvertElementwiseOps() : ConvertElementwiseOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ElementwiseOpConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + + patterns.add(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertElementwiseOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace 
mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp new file mode 100644 index 000000000000..2787247a731c --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -0,0 +1,456 @@ +#include "TypeConverter.h" + +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTMEMORYOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template +struct MemoryOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getContext; + using OpConversionPattern::getTypeConverter; + + MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, + ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, + TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context), + axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis) {} + + Value extractScalarPointer(Location loc, Value ptrs, + ArrayRef indices, + ConversionPatternRewriter &rewriter) const { + // TODO: Analyze data flow and build scalar pointer computation code. 
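+    // For now the pointer tensor is simply materialized as a vector of i64,
+    // the element at the given indices is extracted, and the result is cast
+    // back to a pointer value.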
+ Value ptr = rewriter.create( + loc, rewriter.getRemappedValue(ptrs), indices); + auto ptrTy = dyn_cast(ptrs.getType()).getElementType(); + ptr = rewriter.create(loc, ptrTy, ptr); + return ptr; + } + + Value extractMemRef(Location loc, Value ptr, + ConversionPatternRewriter &rewriter) const { + auto tensorTy = dyn_cast( + dyn_cast(ptr.getType()).getPointeeType()); + auto elemTy = tensorTy.getElementType(); + auto shapeInfo = shapeAnalysis.getPtrShapeInfo(ptr); + Type memRefTy; + if (shapeInfo && shapeInfo->getRank() > 0) { + auto layout = + StridedLayoutAttr::get(getContext(), 0, shapeInfo->getStrides()); + memRefTy = MemRefType::get(shapeInfo->getShape(), elemTy, layout); + } else { + SmallVector dynVals(tensorTy.getRank(), ShapedType::kDynamic); + auto layout = StridedLayoutAttr::get(getContext(), 0, dynVals); + memRefTy = MemRefType::get(dynVals, elemTy, layout); + } + return rewriter.create(loc, memRefTy, ptr); + } + + Value convertOtherVal(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + if (loadOp.getOther()) + return rewriter.getRemappedValue(loadOp.getOther()); + + auto resTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + return rewriter.create( + loadOp.getLoc(), resTy, + SplatElementsAttr::get(resTy, + rewriter.getZeroAttr(resTy.getElementType()))); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysis; + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; +}; + +struct LoadOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + LogicalResult + matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + auto mask = loadOp.getMask(); + auto ptr = loadOp.getPtr(); + auto boundaryChecks = loadOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (axisInfo) { + return lowerUsingAxisInfo(axisInfo, loadOp, rewriter); + } + return lowerToScalarLoads(loadOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported load op"); + } + + auto memRef = extractMemRef(loc, ptr, rewriter); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto resTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + auto vecRead = rewriter.create(loc, resTy, memRef, + indices, inBounds); + rewriter.replaceOp(loadOp, vecRead); + return success(); + } + + LogicalResult lowerUsingAxisInfo(AxisInfo *axisInfo, triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + // This is an experimental code that covers only a simple case of axis info + // usage to demostrate load by tensor of pointers transformation into vector + // loads. + // TODO: Support more cases. + // TODO: Make separate pass to produce block pointer stores? + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + auto shape = vecTy.getShape(); + auto contiguity = axisInfo->getContiguity(); + if (shape.back() > 1 && shape.back() == contiguity.back()) { + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type subVecTy = VectorType::get(shape.back(), vecTy.getElementType()); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = loadOp.getMask() + ? 
rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + Value defaultVal = convertOtherVal(loadOp, rewriter); + Value res = defaultVal; + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subIndices(indices.begin(), + indices.begin() + indices.size() - 1); + auto ptr = + extractScalarPointer(loc, loadOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + Value vec; + if (mask) { + Value subMask = mask; + Value passThru = defaultVal; + if (shape.size() > 1) { + subMask = rewriter.create(loc, mask, subIndices); + passThru = + rewriter.create(loc, defaultVal, subIndices); + } + vec = rewriter.create( + loc, subVecTy, memRef, zeroIdx, subMask, passThru); + } else { + vec = rewriter.create(loc, subVecTy, memRef, zeroIdx); + } + + if (shape.size() > 1) { + res = rewriter.create(loc, vec, res, subIndices); + } else { + res = vec; + } + } + + rewriter.replaceOp(loadOp, res); + return success(); + } + + return lowerToScalarLoads(loadOp, rewriter); + } + + LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + // Scalar loads and boundary checks are not expected. + assert(loadOp.getBoundaryCheck().empty()); + assert(isa(loadOp.getType())); + + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); + auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + auto ptrTy = + dyn_cast(loadOp.getPtr().getType()).getElementType(); + auto cache = loadOp.getCache(); + auto evict = loadOp.getEvict(); + auto isVolatile = loadOp.getIsVolatile(); + Value dst = convertOtherVal(loadOp, rewriter); + int64_t numElems = vecTy.getNumElements(); + auto strides = computeStrides(vecTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Block *headerBlock = rewriter.getBlock(); + Block *condBlock = nullptr; + Value origDst = dst; + // Create a conditional block for load if there is a mask. + if (mask) { + condBlock = + rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(condBlock); + } + + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = + rewriter.create(loc, ptr, cache, evict, isVolatile); + dst = rewriter.create(loc, val, dst, indices); + + // Add predicate and branches. 
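+      // The masked case is built as explicit CFG: the header conditionally
+      // branches to the load block or directly to the footer, which receives
+      // the (possibly updated) accumulator as a block argument.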
+ if (mask) { + Block *footerBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + Value resDst = dst; + dst = footerBlock->addArgument(dst.getType(), dst.getLoc()); + rewriter.setInsertionPointToEnd(headerBlock); + auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, condBlock, + footerBlock, origDst); + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, footerBlock, resDst); + rewriter.setInsertionPointToStart(footerBlock); + } + } + + rewriter.replaceOp(loadOp, dst); + + return success(); + } +}; + +struct StoreOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + LogicalResult + matchAndRewrite(triton::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + auto mask = storeOp.getMask(); + auto ptr = storeOp.getPtr(); + auto boundaryChecks = storeOp.getBoundaryCheck(); + + if (!triton::isTensorPointerType(ptr.getType())) { + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (axisInfo) { + return lowerUsingAxisInfo(axisInfo, storeOp, rewriter); + } + return lowerToScalarStores(storeOp, rewriter); + } + + // TODO: support masks. + if (mask) { + llvm_unreachable("unsupported store op"); + } + + auto value = rewriter.getRemappedValue(storeOp.getValue()); + auto memRef = extractMemRef(loc, ptr, rewriter); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto indices = rewriter.create(loc, ptr).getResults(); + SmallVector inBounds(rank, true); + for (auto dim : boundaryChecks) { + inBounds[dim] = false; + } + auto vecWrite = rewriter.create(loc, value, memRef, + indices, inBounds); + rewriter.replaceOp(storeOp, vecWrite); + return success(); + } + + LogicalResult lowerUsingAxisInfo(AxisInfo *axisInfo, triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + // This is an experimental code that covers only a simple case of axis info + // usage to demostrate load by tensor of pointers transformation into vector + // loads. + // TODO: Support more cases. + // TODO: Make separate pass to produce block pointer stores instead? + auto loc = storeOp.getLoc(); + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto vecTy = dyn_cast(vals.getType()); + auto shape = vecTy.getShape(); + auto contiguity = axisInfo->getContiguity(); + if (shape.back() > 1 && shape.back() == contiguity.back()) { + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = storeOp.getMask() + ? 
rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + auto ptr = + extractScalarPointer(loc, storeOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + indices.pop_back(); + auto val = rewriter.create(loc, vals, indices); + + if (mask) { + Value subMask = mask; + if (shape.size() > 1) { + SmallVector subIndices = indices; + subIndices.pop_back(); + subMask = rewriter.create(loc, mask, indices); + } + rewriter.create(loc, memRef, zeroIdx, subMask, + val); + } else { + rewriter.create(loc, val, memRef, zeroIdx); + } + } + + rewriter.eraseOp(storeOp); + return success(); + } + + return lowerToScalarStores(storeOp, rewriter); + } + + LogicalResult lowerToScalarStores(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + // Scalar stores and boundary checks are not expected. + assert(storeOp.getBoundaryCheck().empty()); + assert(isa(storeOp.getValue().getType())); + + auto loc = storeOp.getLoc(); + auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); + auto mask = storeOp.getMask() ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto tensorTy = dyn_cast(storeOp.getPtr().getType()); + auto ptrTy = tensorTy.getElementType(); + auto cache = storeOp.getCache(); + auto evict = storeOp.getEvict(); + + int64_t numElems = tensorTy.getNumElements(); + auto strides = computeStrides(tensorTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Block *headerBlock = rewriter.getBlock(); + Block *condBlock = nullptr; + // Create a conditional block for store if there is a mask. + if (mask) { + condBlock = + rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(condBlock); + } + + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + rewriter.create(loc, ptr, val, cache, evict); + + // Add predicate and branches. + if (mask) { + Block *footerBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(headerBlock); + auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, condBlock, + footerBlock); + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, footerBlock); + rewriter.setInsertionPointToStart(footerBlock); + } + } + + rewriter.eraseOp(storeOp); + + return success(); + } +}; + +class MemoryOpConversionTarget : public ConversionTarget { +public: + explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + // Allow only scalar loads and stores. 
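+    // Tensor loads and stores must be rewritten by this pass; scalar accesses
+    // are left in place and handled later by the TritonCPU -> LLVM lowering.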
+ addDynamicallyLegalOp([](triton::LoadOp loadOp) { + return loadOp.getType().isIntOrIndexOrFloat(); + }); + addDynamicallyLegalOp([](triton::StoreOp storeOp) { + return storeOp.getValue().getType().isIntOrIndexOrFloat(); + }); + } +}; + +struct ConvertMemoryOps + : public triton::impl::ConvertMemoryOpsBase { + using ConvertMemoryOpsBase::ConvertMemoryOpsBase; + + ConvertMemoryOps() : ConvertMemoryOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod); + MemoryOpConversionTarget convTarget(*context); + TritonToTritonCPUTypeConverter pointerConverter; + RewritePatternSet patterns(context); + patterns.add(axisInfoAnalysis, shapeInfoAnalysis, + pointerConverter, context); + patterns.add(axisInfoAnalysis, shapeInfoAnalysis, + pointerConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertMemoryOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp new file mode 100644 index 000000000000..82123c376dc1 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -0,0 +1,195 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTPTROPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +unsigned getElemBitWidth(Type type) { + if (auto tensorTy = dyn_cast(type)) + return tensorTy.getElementType().getIntOrFloatBitWidth(); + if (auto vectorTy = dyn_cast(type)) + return vectorTy.getElementType().getIntOrFloatBitWidth(); + return type.getIntOrFloatBitWidth(); +} + +class PtrConversionTarget : public ConversionTarget { +public: + explicit PtrConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + // Scalar pointer operations are translated directly to LLVM. 
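+    // Scalar tt.ptr_to_int, tt.int_to_ptr, and tt.addptr therefore remain
+    // legal here; only their tensor forms are rewritten into vector integer
+    // arithmetic by the patterns below.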
+ addDynamicallyLegalOp( + [](triton::PtrToIntOp op) { return op.getType().isInteger(); }); + addDynamicallyLegalOp([](triton::IntToPtrOp op) { + return op.getSrc().getType().isInteger(); + }); + addDynamicallyLegalOp( + [](triton::AddPtrOp op) { return isa(op.getType()); }); + } +}; + +struct MakeRangeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int32_t start = static_cast(op.getStart()); + int32_t end = static_cast(op.getEnd()); + assert(end >= start); + + llvm::SmallVector values; + values.reserve(end - start); + for (int32_t v = start; v < end; ++v) { + values.push_back(v); + } + + Type resTy = getTypeConverter()->convertType(op.getType()); + auto newOp = rewriter.create( + op.getLoc(), resTy, rewriter.getI32VectorAttr(values)); + + rewriter.replaceOp(op, newOp); + return success(); + } +}; + +struct SplatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value val = op.getSrc(); + // Cast pointer + if (isa(val.getType())) + val = rewriter.create(loc, rewriter.getI64Type(), val) + .getResult(); + Type resType = getTypeConverter()->convertType(op.getType()); + auto cast = rewriter.create(loc, resType, val); + + rewriter.replaceOp(op, cast); + return success(); + } +}; + +struct AddPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value ptr = rewriter.getRemappedValue(op.getPtr()); + Value offset = rewriter.getRemappedValue(op.getOffset()); + unsigned offsetBitWidth = getElemBitWidth(offset.getType()); + unsigned elemBitWidth = getPointeeBitWidth(op.getPtr().getType()); + // Scalar case is not expected. + assert(isa(offset.getType())); + assert(isa(ptr.getType())); + VectorType offsetTy = cast(offset.getType()); + // Build scale vector. i1 elements take 1 byte. 
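+    // Pointers are modeled as i64 vectors at this level, so tt.addptr becomes
+    // integer arithmetic: offsets are scaled by the element size in bytes,
+    // (elemBitWidth + 7) / 8, and added to the base pointers.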
+ Value scale = rewriter.create( + loc, offsetTy, + SplatElementsAttr::get( + offsetTy, rewriter.getIntegerAttr(offsetTy.getElementType(), + (elemBitWidth + 7) / 8))); + offset = rewriter.create(loc, offset, scale); + offset = rewriter.create(loc, ptr.getType(), offset); + rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); + return success(); + } +}; + +struct PtrToIntOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct IntToPtrOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::IntToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value val = rewriter.getRemappedValue(op.getSrc()); + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, val); + return success(); + } +}; + +struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { + using ConvertPtrOpsBase::ConvertPtrOpsBase; + + ConvertPtrOps() : ConvertPtrOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + PtrConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertPtrOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h b/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h new file mode 100644 index 000000000000..aaac6a27d5e6 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/OpTypeConversion.h @@ -0,0 +1,37 @@ +#include "mlir/IR/OperationSupport.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +// Generic pattern to rewrite operation by converting types +// for operation operands and results using provided type +// converter. +template +struct OpTypeConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using typename OpConversionPattern::OpAdaptor; + + LogicalResult + matchAndRewrite(OpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + OperationState newState(op.getLoc(), ResOpT::getOperationName()); + // Convert operands. + for (auto operand : op->getOperands()) { + Value newOperand = rewriter.getRemappedValue(operand); + newState.operands.push_back(newOperand); + } + // Convert result types. 
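+    // Result types go through the same converter as the operands (tensors
+    // become vectors); failing to convert them aborts the rewrite rather than
+    // producing a half-converted op.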
+ if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + newState.types))) { + return failure(); + } + newState.attributes = op->getAttrs(); + + auto newOp = rewriter.create(newState); + rewriter.replaceOp(op, newOp); + + return success(); + } +}; diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp new file mode 100644 index 000000000000..d954142d9172 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -0,0 +1,27 @@ +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +namespace triton { +namespace cpu { + +void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); + pm.addPass(mlir::triton::cpu::createConvertPtrOps()); + pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); + pm.addPass(mlir::triton::cpu::createConvertDotOp()); + pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); + // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); +} + +void registerTritonToTritonCPUPipeline() { + PassPipelineRegistration<>("triton-to-triton-cpu", + "Triton to TritonCPU conversion pipeline.", + tritonToTritonCPUPipelineBuilder); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp new file mode 100644 index 000000000000..ce66f8faeb3e --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp @@ -0,0 +1,34 @@ +#include "TypeConverter.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { + addConversion([](Type type) { return type; }); + addConversion([this](RankedTensorType tensorTy) -> Type { + Type elemTy = convertType(tensorTy.getElementType()); + if (isa(elemTy)) + elemTy = IntegerType::get(tensorTy.getContext(), 64); + return VectorType::get(tensorTy.getShape(), elemTy); + }); + + // Converted ops produce vectors instead of tensors. Provide conversion + // here for users. + addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> std::optional { + return builder.create(loc, type, inputs) + .getResult(0); + }); + + // Provide conversion for vector users. 
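+  // The target materialization wraps values in unrealized_conversion_cast ops
+  // so vector consumers can be converted before their producers; these casts
+  // are expected to fold away once the full conversion has run.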
+ addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, + Location loc) -> std::optional { + if (isa(type)) + return builder.create(loc, type, inputs) + .getResult(0); + llvm_unreachable("Unexpected target materizalization"); + }); +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h new file mode 100644 index 000000000000..cb89f0886c60 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.h @@ -0,0 +1,19 @@ +#ifndef TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H +#define TRITON_CONVERSION_TRITON_TO_TRITONCPU_TYPECONVERTER_H + +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using namespace mlir; +using namespace mlir::triton; + +class TritonToTritonCPUTypeConverter : public TypeConverter { +public: + using TypeConverter::convertType; + + TritonToTritonCPUTypeConverter(); + + Type convertTritonPointerType(triton::PointerType type); +}; + +#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 302951d04d59..8065098becbe 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,9 +1,20 @@ +#include "TritonCPUToLLVM/Passes.h" +#include "TritonToTritonCPU/Passes.h" + +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/Passes.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "triton/Conversion/TritonCPUToLLVM/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" + #include #include #include @@ -14,8 +25,37 @@ namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; - m.def("add_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); + // m.def("add_to_llvmir", [](mlir::PassManager &pm) { + // pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); + // }); + m.def("add_triton_to_triton_cpu_pipeline", [](mlir::PassManager &pm) { + mlir::triton::cpu::tritonToTritonCPUPipelineBuilder(pm); + }); + m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { + mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); + }); + m.def("add_vector_to_scf", [](mlir::PassManager &pm, bool full_unroll, + unsigned target_rank, bool lower_tensors) { + mlir::VectorTransferToSCFOptions opts; + opts.setTargetRank(target_rank); + opts.enableFullUnroll(full_unroll); + opts.enableLowerTensors(lower_tensors); + pm.addPass(mlir::createConvertVectorToSCFPass(opts)); + }); + m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertVectorToLLVMPass()); + }); + m.def("add_lower_affine", [](mlir::PassManager &pm) { + pm.addPass(mlir::createLowerAffinePass()); + }); + m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); + }); + m.def("add_math_to_libm", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertMathToLibmPass()); + }); + m.def("add_func_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::createConvertFuncToLLVMPass()); }); } @@ -25,8 +65,18 @@ void init_triton_cpu(py::module &&m) { m.def("load_dialects", 
[](mlir::MLIRContext &context) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); + + m.def("find_kernel_names", [](mlir::ModuleOp &mod) { + std::vector res; + mod.walk([&](mlir::FunctionOpInterface funcOp) { + if (funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) + res.push_back(funcOp.getName().str()); + }); + return res; + }); } From e1dd10e398844c535830dc5cc2cc194d56a07adc Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 28 May 2024 16:02:14 -0500 Subject: [PATCH 013/165] Add support for tl.cat operation. (#9) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 1 + .../ConvertElementwiseOps.cpp | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 697df74ab7be..fd23d610739f 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1920,6 +1920,7 @@ def kernel(X, Z, TO_TYPE: tl.constexpr, BITCAST: tl.constexpr, SIZE: tl.constexp np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 218dd827619a..cadec818910b 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -60,6 +60,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -236,6 +237,24 @@ struct JoinOpConversion : public OpConversionPattern { } }; +struct CatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + SmallVector indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]); + std::iota(indices.begin(), indices.end(), 0); + rewriter.replaceOpWithNewOp(op, lhs, rhs, indices); + return success(); + } +}; + struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -320,6 +339,7 @@ struct ConvertElementwiseOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From 7b183cf49ea8865f3951e13161fc3947cfc07f94 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Tue, 28 May 2024 14:49:37 -0700 Subject: [PATCH 014/165] [BACKEND][CPU] Make it buildable and runnable in a different environment (#8) * [BACKEND][CPU] Make it buildable and runnable in a different environment * Revert seemingly inconsistent python code formatting --- include/triton/Conversion/CMakeLists.txt | 5 +-- lib/Conversion/CMakeLists.txt | 5 +-- python/src/passes.cc | 3 -- python/triton/runtime/build.py | 34 
++++++++++++++++++- third_party/cpu/backend/compiler.py | 2 +- third_party/cpu/backend/driver.py | 4 ++- .../include/TritonCPUToLLVM/CMakeLists.txt | 2 +- 7 files changed, 44 insertions(+), 11 deletions(-) diff --git a/include/triton/Conversion/CMakeLists.txt b/include/triton/Conversion/CMakeLists.txt index ae31ac930b7e..3b8a95e1ecf7 100644 --- a/include/triton/Conversion/CMakeLists.txt +++ b/include/triton/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ -add_subdirectory(TritonCPUToLLVM) +# TODO(minjang): I will remove these scratches soon. +# add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) -add_subdirectory(TritonToTritonCPU) +# add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 83db4ae41607..426b22a42ef6 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,5 @@ -#add_subdirectory(TritonToTritonCPU) +# TODO(minjang): I will remove these scratches soon. +# add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -#add_subdirectory(TritonCPUToLLVM) +# add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/python/src/passes.cc b/python/src/passes.cc index 9e34f6ad7fed..f9dbfe16bc64 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -6,7 +6,6 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" -#include "triton/Conversion/TritonToTritonCPU/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Dialect/Triton/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -45,8 +44,6 @@ void init_triton_passes_ttir(py::module &&m) { ADD_PASS_WRAPPER_4("add_convert_to_ttgpuir", createConvertTritonToTritonGPUPass, const std::string &, int, int, int); - // ADD_PASS_WRAPPER_0("add_convert_to_ttcpuir", - // createConvertTritonToTritonCPUPass); } void init_triton_passes_ttgpuir(py::module &&m) { diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 1b76548d43a7..19f74254dd6b 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -33,5 +33,37 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] - subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL) + # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. 
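    # The flag is only added for C++ sources so that the plain C launchers used
    # by the GPU backends keep building with the default flags.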
+ if src.endswith(".cpp") or src.endswith(".cc"): + cc_cmd += ["-std=c++17"] + ret = subprocess.check_call(cc_cmd) + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) return so diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 344cdd2f05ae..3daf83eaac28 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -20,7 +20,7 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False - allowed_dot_input_precisions: Tuple[str] = ("ieee",) + allowed_dot_input_precisions: Tuple[str] = ("ieee", ) allow_fp8e4nv: bool = False enable_fp_fusion: bool = True diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 3fe243fc262d..5783a0342dbd 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -74,6 +74,7 @@ "LLVMSupport", "LLVMDemangle", "stdc++", + "z", ] @@ -176,7 +177,8 @@ def format_of(ty): arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' kernel_fn_args = [i for i in signature.keys() if i not in constants] kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else '' - kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" + kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + + ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" # generate glue code src = f""" diff --git a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt index 64b36523d35d..0936dff12d91 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/include/TritonCPUToLLVM/CMakeLists.txt @@ -1,3 +1,3 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) -add_public_tablegen_target(TritonCPUConversionPassIncGen) +add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen) From d6df9c186541e22ae100e0273ca1c1924279f7bb Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 29 May 2024 15:19:16 -0500 Subject: [PATCH 015/165] Add support for simple reductions. 
(#10) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 4 + third_party/cpu/backend/compiler.py | 1 + .../cpu/include/TritonCPUToLLVM/Passes.h | 3 + .../cpu/include/TritonCPUToLLVM/Passes.td | 11 + .../cpu/include/TritonToTritonCPU/Passes.h | 1 + .../cpu/include/TritonToTritonCPU/Passes.td | 14 ++ .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 1 + .../TritonCPUToLLVM/LowerMultiReduction.cpp | 59 +++++ .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 1 + .../TritonToTritonCPU/ConvertReductionOp.cpp | 218 ++++++++++++++++++ .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 1 + third_party/cpu/triton_cpu.cc | 4 + 12 files changed, 318 insertions(+) create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index fd23d610739f..e04d99612dfa 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2379,6 +2379,7 @@ def kernel(X, Z, BLOCK: tl.constexpr): reduce_bool = [(op, 'bool', shape, axis, False) for op in ['xor_sum'] for shape in reduce2d_shapes for axis in [0, 1]] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_str, shape, axis, keep_dims", reduce_configs1 + reduce_configs2 + reduce_configs3 + invalid_config + @@ -2387,6 +2388,9 @@ def kernel(X, Z, BLOCK: tl.constexpr): def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + if is_cpu() and op in ('argmin', 'argmax'): + pytest.skip(f"Not yet implemented on CPU: {op}") + @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 3daf83eaac28..d48fbf3a96bf 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -98,6 +98,7 @@ def make_llir(src, metadata, options): # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm) cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) cpu.passes.ttcpuir.add_lower_affine(pm) passes.convert.add_scf_to_cf(pm) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index 74f74b00870c..a1fbce2e4892 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -5,6 +5,8 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + #include namespace mlir { @@ -21,6 +23,7 @@ namespace cpu { std::unique_ptr> createFuncOpToLLVMPass(); std::unique_ptr> createMemoryOpToLLVMPass(); std::unique_ptr> createGetProgramIdOpToLLVMPass(); +std::unique_ptr> createLowerMultiReductionPass(); void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); void registerTritonCPUToLLVMPipeline(); diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index c75b58b572f1..2abe88338dcf 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -43,4 +43,15 @@ def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", 
"mlir::M "mlir::triton::TritonDialect"]; } +def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton::FuncOp"> { + let summary = "Convert multi-dimensional reductions."; + let description = [{ + }]; + let constructor = "mlir::triton::cpu::createLowerMultiReductionPass()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; +} + #endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 745799039691..5893c99f250e 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -23,6 +23,7 @@ std::unique_ptr> createConvertMemoryOps(); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); +std::unique_ptr> createConvertReductionOp(); void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); void registerTritonToTritonCPUPipeline(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 5f52f3a2e31d..a2663bea5589 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -74,4 +74,18 @@ def ConvertControlFlowOps : Pass<"triton-cpu-convert-control-flow-op", "mlir::Mo "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> { + let summary = "Convert Triton ReduceOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertReductionOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index 884c9352ef1b..0cf83bc03b06 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonCPUToLLVM FuncOpToLLVM.cpp GetProgramIdOpToLLVM.cpp + LowerMultiReduction.cpp MemoryOpToLLVM.cpp Pipeline.cpp TypeConverter.cpp diff --git a/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp new file mode 100644 index 000000000000..5c18f15d9d1b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp @@ -0,0 +1,59 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_LOWERMULTIREDUCTION +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// This pass exists because LowerVectorMultiReductionPass can be run on +// func::FuncOp only and we translate triton::FuncOp directly into llvm::FuncOp. +// So we run the same set of patterns on triton::FuncOp. 
+struct LowerMultiReduction + : public mlir::triton::impl::LowerMultiReductionBase { + using LowerMultiReductionBase::LowerMultiReductionBase; + + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + + RewritePatternSet loweringPatterns(context); + vector::VectorMultiReductionLowering options; + vector::populateVectorMultiReductionLoweringPatterns(loweringPatterns, + options); + + if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) + signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createLowerMultiReductionPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index 997fb748878a..d18488e5aef0 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonToTritonCPU ConvertDotOp.cpp ConvertElementwiseOps.cpp ConvertMemoryOps.cpp + ConvertReductionOp.cpp ConvertPtrOps.cpp Pipeline.cpp TypeConverter.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp new file mode 100644 index 000000000000..a7c1a28d95b1 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -0,0 +1,218 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTREDUCTIONOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ReductionConversionTarget : public ConversionTarget { +public: + explicit ReductionConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct ReduceOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + // Currently, only simple reductions with a single input argumet are + // supported. + // TODO: support generic case. 
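      // The matching below handles the common case produced by tl.sum/tl.max and
      // friends: a reduce region with exactly two block arguments and a single
      // binary op combining them. That op is mapped to a vector::CombiningKind and
      // the whole tt.reduce is rewritten as one vector.multi_reduction seeded with
      // that kind's neutral init value.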
+ if (op.getNumOperands() != 1 || op.getNumResults() != 1) + return failure(); + + Value src = rewriter.getRemappedValue(op.getOperand(0)); + VectorType srcTy = dyn_cast(src.getType()); + assert(srcTy); + + Block *block = op.getBody(); + if (block->getNumArguments() != 2) + return failure(); + Value itArg = block->getArgument(0); + Value accArg = block->getArgument(1); + + auto &blockOps = block->getOperations(); + if (blockOps.size() != 2) + return failure(); + + Operation &retOp = blockOps.back(); + if (!isa(retOp) || retOp.getNumOperands() != 1) + return failure(); + + Value retVal = retOp.getOperand(0); + Operation *defOp = retVal.getDefiningOp(); + if (!defOp || defOp->getNumOperands() != 2) + return failure(); + + Value lhs = defOp->getOperand(0); + Value rhs = defOp->getOperand(1); + if ((lhs != itArg || rhs != accArg) && (lhs != accArg || rhs != itArg)) + return failure(); + + vector::CombiningKind reductionKind; + if (failed(detectReductionKind(defOp, reductionKind))) + return failure(); + + Type resTy = getTypeConverter()->convertType(op.getType(0)); + Value acc = buildInitValue(op.getLoc(), resTy, reductionKind, rewriter); + int64_t axis = op.getAxis(); + rewriter.replaceOpWithNewOp( + op, resTy, reductionKind, src, acc, axis); + return success(); + } + + LogicalResult detectReductionKind(Operation *op, + vector::CombiningKind &out) const { + if (isa(op)) + out = vector::CombiningKind::ADD; + else if (isa(op)) + out = vector::CombiningKind::MUL; + else if (isa(op)) + out = vector::CombiningKind::MINSI; + else if (isa(op)) + out = vector::CombiningKind::MINUI; + else if (isa(op)) + out = vector::CombiningKind::MINIMUMF; + else if (isa(op)) + out = vector::CombiningKind::MINNUMF; + else if (isa(op)) + out = vector::CombiningKind::MAXSI; + else if (isa(op)) + out = vector::CombiningKind::MAXUI; + else if (isa(op)) + out = vector::CombiningKind::MAXIMUMF; + else if (isa(op)) + out = vector::CombiningKind::MAXNUMF; + else if (isa(op)) + out = vector::CombiningKind::AND; + else if (isa(op)) + out = vector::CombiningKind::OR; + else if (isa(op)) + out = vector::CombiningKind::XOR; + else + return failure(); + return success(); + } + + Value buildInitValue(Location loc, Type resTy, vector::CombiningKind kind, + ConversionPatternRewriter &rewriter) const { + VectorType vecTy = dyn_cast(resTy); + Type elemTy = vecTy ? 
vecTy.getElementType() : resTy; + + TypedAttr initVal; + if (kind == vector::CombiningKind::ADD || + kind == vector::CombiningKind::OR || + kind == vector::CombiningKind::XOR || + kind == vector::CombiningKind::MAXUI) + initVal = rewriter.getZeroAttr(elemTy); + else if (kind == vector::CombiningKind::MUL) + initVal = rewriter.getOneAttr(elemTy); + else if (kind == vector::CombiningKind::AND || + kind == vector::CombiningKind::MINUI) + initVal = rewriter.getIntegerAttr(elemTy, -1); + else if (kind == vector::CombiningKind::MAXSI) + initVal = rewriter.getIntegerAttr( + elemTy, + static_cast(1UL << (elemTy.getIntOrFloatBitWidth() - 1))); + else if (kind == vector::CombiningKind::MINSI) + initVal = rewriter.getIntegerAttr( + elemTy, static_cast( + 1UL << (elemTy.getIntOrFloatBitWidth() - 1) - 1)); + else if (kind == vector::CombiningKind::MINIMUMF || + kind == vector::CombiningKind::MINNUMF) { + if (elemTy.isF32()) + initVal = + rewriter.getF32FloatAttr(std::numeric_limits::infinity()); + else if (elemTy.isF64()) + initVal = + rewriter.getF64FloatAttr(std::numeric_limits::infinity()); + else + llvm_unreachable("Unsupported type for acc init value."); + } else if (kind == vector::CombiningKind::MAXIMUMF || + kind == vector::CombiningKind::MAXNUMF) { + if (elemTy.isF32()) + initVal = + rewriter.getF32FloatAttr(-std::numeric_limits::infinity()); + else if (elemTy.isF64()) + initVal = + rewriter.getF64FloatAttr(-std::numeric_limits::infinity()); + else + llvm_unreachable("Unsupported type for acc init value."); + } + + if (vecTy) + initVal = SplatElementsAttr::get(vecTy, initVal); + + return rewriter.create(loc, resTy, initVal); + } +}; + +struct ConvertReductionOp + : public triton::impl::ConvertReductionOpBase { + using ConvertReductionOpBase::ConvertReductionOpBase; + + ConvertReductionOp() : ConvertReductionOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ReductionConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertReductionOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index d954142d9172..87d72f7a6473 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -12,6 +12,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertPtrOps()); pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); pm.addPass(mlir::triton::cpu::createConvertDotOp()); + pm.addPass(mlir::triton::cpu::createConvertReductionOp()); pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); } diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 8065098becbe..fa4eb818dce5 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -42,6 +42,10 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { opts.enableLowerTensors(lower_tensors); pm.addPass(mlir::createConvertVectorToSCFPass(opts)); }); + 
m.def("add_lower_vector_multi_dim", [](mlir::PassManager &pm) { + pm.addNestedPass( + mlir::triton::cpu::createLowerMultiReductionPass()); + }); m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertVectorToLLVMPass()); }); From ad823a3ac41cce40d9421b2ea4f74b25de23dbae Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 29 May 2024 15:23:30 -0500 Subject: [PATCH 016/165] Support tl.histogram for CPU. (#12) Signed-off-by: Ilya Enkovich Co-authored-by: Minjang Kim --- python/test/unit/language/test_core.py | 1 + .../cpu/include/TritonToTritonCPU/Passes.h | 1 + .../cpu/include/TritonToTritonCPU/Passes.td | 11 ++ .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 1 + .../TritonToTritonCPU/ConvertHistogramOp.cpp | 134 ++++++++++++++++++ .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 1 + 6 files changed, 149 insertions(+) create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e04d99612dfa..35c5da7a27b7 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2709,6 +2709,7 @@ def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.const # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) def test_histogram(M, N, device): diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 5893c99f250e..f67c2de7e2ce 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -23,6 +23,7 @@ std::unique_ptr> createConvertMemoryOps(); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); +std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index a2663bea5589..6604ca4fcc12 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -74,6 +74,17 @@ def ConvertControlFlowOps : Pass<"triton-cpu-convert-control-flow-op", "mlir::Mo "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertHistogramOp : Pass<"triton-cpu-convert-histogram-op", "mlir::ModuleOp"> { + let summary = "Convert Triton HistogramOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertHistogramOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> { let summary = "Convert Triton ReduceOp."; let description = [{ diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index d18488e5aef0..d7974fe63079 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonToTritonCPU ConvertControlFlowOps.cpp ConvertDotOp.cpp ConvertElementwiseOps.cpp + ConvertHistogramOp.cpp ConvertMemoryOps.cpp ConvertReductionOp.cpp ConvertPtrOps.cpp diff --git 
a/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp new file mode 100644 index 000000000000..0bcbfcc9f264 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp @@ -0,0 +1,134 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTHISTOGRAMOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class HistogramConversionTarget : public ConversionTarget { +public: + explicit HistogramConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + + addIllegalOp(); + } +}; + +struct HistogramOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = dyn_cast(src.getType()); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + + if (srcTy.getRank() != 1) + llvm_unreachable("unsupported input for histogram op (rank != 1)"); + + Value zero = rewriter.create( + loc, resTy, rewriter.getZeroAttr(resTy)); + Value one = rewriter.create(loc, resTy, + rewriter.getOneAttr(resTy)); + VectorType cmpVecTy = + VectorType::get(resTy.getShape(), srcTy.getElementType()); + Value rangeVec = rewriter.create( + loc, resTy, makeRangeAttr(cmpVecTy, rewriter)); + Value res = zero; + for (int64_t i = 0; i < srcTy.getShape()[0]; ++i) { + Value idx = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(i)); + Value elem = rewriter.create(loc, src, idx); + Value elemVec = rewriter.create(loc, cmpVecTy, elem); + Value mask = rewriter.create(loc, arith::CmpIPredicate::eq, + elemVec, rangeVec); + Value delta = vector::selectPassthru(rewriter, mask, one, zero); + res = rewriter.create(loc, res, delta); + } + + rewriter.replaceOp(op, res); + + return success(); + } + + TypedAttr makeRangeAttr(VectorType resTy, + ConversionPatternRewriter &rewriter) const { + Type elemTy = resTy.getElementType(); + if (elemTy.isInteger(32)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI32VectorAttr(range); + } else if (elemTy.isInteger(64)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI64VectorAttr(range); + } else { + llvm_unreachable( + "unsupported src elem type for histogram (expected i32 or i64)"); + } + } +}; + +struct ConvertHistogramOp + : public triton::impl::ConvertHistogramOpBase { + using ConvertHistogramOpBase::ConvertHistogramOpBase; + + ConvertHistogramOp() : ConvertHistogramOpBase() {} + + void 
runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + HistogramConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertHistogramOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index 87d72f7a6473..50d5814270d7 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -12,6 +12,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertPtrOps()); pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); pm.addPass(mlir::triton::cpu::createConvertDotOp()); + pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); pm.addPass(mlir::triton::cpu::createConvertReductionOp()); pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); From 61a99a0e94a4c069b4edd7c74a07bd3030602b20 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Thu, 30 May 2024 16:28:57 -0700 Subject: [PATCH 017/165] Fix merge and compile errors (#13) --- third_party/cpu/include/TritonToTritonCPU/Passes.td | 3 ++- third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 6604ca4fcc12..60b9942d08cd 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -83,7 +83,8 @@ def ConvertHistogramOp : Pass<"triton-cpu-convert-histogram-op", "mlir::ModuleOp let dependentDialects = ["mlir::arith::ArithDialect", "mlir::memref::MemRefDialect", - "mlir::vector::VectorDialect", + "mlir::vector::VectorDialect"]; +} def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> { let summary = "Convert Triton ReduceOp."; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index a7c1a28d95b1..c69d8322f82d 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -149,11 +149,11 @@ struct ReduceOpConversion : public OpConversionPattern { else if (kind == vector::CombiningKind::MAXSI) initVal = rewriter.getIntegerAttr( elemTy, - static_cast(1UL << (elemTy.getIntOrFloatBitWidth() - 1))); + static_cast(-(1UL << (elemTy.getIntOrFloatBitWidth() - 1)))); else if (kind == vector::CombiningKind::MINSI) initVal = rewriter.getIntegerAttr( elemTy, static_cast( - 1UL << (elemTy.getIntOrFloatBitWidth() - 1) - 1)); + (1UL << (elemTy.getIntOrFloatBitWidth() - 1)) - 1)); else if (kind == vector::CombiningKind::MINIMUMF || kind == vector::CombiningKind::MINNUMF) { if (elemTy.isF32()) From 1c0986c7b7ed456d35e4a4fca60b4d8d72eac019 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Fri, 31 May 2024 11:10:23 -0700 Subject: [PATCH 018/165] [CPU] Support flexible active driver + update vector-add 
tutorial (#11) * [CPU] Support flexible active driver + update vector-add tutorial * Update vector-add to run CPU always + optional GPU * Update do_bench for CPU --- python/triton/backends/__init__.py | 3 +- python/triton/backends/driver.py | 11 ------ python/triton/runtime/driver.py | 17 ++++++++ python/triton/testing.py | 62 ++++++++++++++++++++++++------ python/tutorials/01-vector-add.py | 27 ++++++++----- 5 files changed, 88 insertions(+), 32 deletions(-) diff --git a/python/triton/backends/__init__.py b/python/triton/backends/__init__.py index 92ba144ba97b..738ea2fef8bc 100644 --- a/python/triton/backends/__init__.py +++ b/python/triton/backends/__init__.py @@ -28,6 +28,7 @@ def _find_concrete_subclasses(module, base_class): @dataclass(frozen=True) class Backend: + name: str = "" compiler: BaseBackend = None driver: DriverBase = None @@ -42,7 +43,7 @@ def _discover_backends(): continue compiler = _load_module(name, os.path.join(root, name, 'compiler.py')) driver = _load_module(name, os.path.join(root, name, 'driver.py')) - backends[name] = Backend(_find_concrete_subclasses(compiler, BaseBackend), + backends[name] = Backend(name, _find_concrete_subclasses(compiler, BaseBackend), _find_concrete_subclasses(driver, DriverBase)) return backends diff --git a/python/triton/backends/driver.py b/python/triton/backends/driver.py index 72347735476b..6606b21ca8a2 100644 --- a/python/triton/backends/driver.py +++ b/python/triton/backends/driver.py @@ -51,14 +51,3 @@ def __init__(self): # TODO: remove once TMA is cleaned up def assemble_tensormap_to_arg(self, tensormaps_info, args): return args - - -class CPUDriverBase(DriverBase): - - def __init__(self): - # Right now, we just provide dummy functions. - # TODO: Consider better engineering the code only intended for GPU in jit.py. - self.get_device_capability = lambda idx: (0, 0) - self.get_current_stream = lambda idx: 0 - self.get_current_device = lambda: 0 - self.set_current_device = lambda idx: None diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index 4cf1aea8e494..ed3c16978bd2 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -66,5 +66,22 @@ def set_active(self, driver: DriverBase): def reset_active(self): self.active = self.default + def set_active_to_cpu(self): + if "cpu" not in backends: + raise RuntimeError("CPU backend is unavailable") + self.active = backends["cpu"].driver() + + def set_active_to_gpu(self): + active_gpus = [(name, backend.driver) + for name, backend in backends.items() + if backend.driver.is_active() and name != "cpu"] + if len(active_gpus) != 1: + raise RuntimeError(f"{len(active_gpus)} active GPU drivers ({active_gpus}). There should only be one GPU.") + self.active = active_gpus[0][1]() + return active_gpus[0][0] + + def get_active_gpus(self): + return [name for name, backend in backends.items() if backend.driver.is_active() and name != "cpu"] + driver = DriverConfig() diff --git a/python/triton/testing.py b/python/triton/testing.py index 9a338f11ccc9..54e7ee7c14ed 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -4,12 +4,35 @@ import statistics import subprocess import sys +import time from contextlib import contextmanager from typing import Any, Dict, List from . import language as tl from . 
import runtime +class Event: + + def __init__(self, is_cpu): + self.time = 0 + self.is_cpu = is_cpu + if not is_cpu: + import torch + self.cuda_event = torch.cuda.Event(enable_timing=True) + + def elapsed_time(self, end_event) -> float: + if self.is_cpu: + return (end_event.time - self.time) * 1000 + else: + return self.cuda_event.elapsed_time(end_event.cuda_event) + + def record(self): + if self.is_cpu: + self.time = time.perf_counter() + else: + self.cuda_event.record() + + def nvsmi(attrs): attrs = ','.join(attrs) cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] @@ -120,7 +143,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod return _summarize_statistics(ret, quantiles, return_mode) -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", is_cpu=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -140,29 +163,46 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m """ assert return_mode in ["min", "max", "mean", "median", "all"] - di = runtime.driver.active.get_device_interface() + if not is_cpu: + di = runtime.driver.active.get_device_interface() fn() - di.synchronize() - - cache = runtime.driver.active.get_empty_cache_for_benchmark() + if not is_cpu: + di.synchronize() + + if not is_cpu: + cache = runtime.driver.active.get_empty_cache_for_benchmark() + if is_cpu: + # Currently, a typical L3 cache size for high-end server CPUs are ~400MB. + cache_size = 512 * 1024 * 1024 + cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cpu') + + if not is_cpu: + # Estimate the runtime of the function + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) + else: + start_event = Event(is_cpu) + end_event = Event(is_cpu) - # Estimate the runtime of the function - start_event = di.Event(enable_timing=True) - end_event = di.Event(enable_timing=True) start_event.record() for _ in range(5): runtime.driver.active.clear_cache(cache) fn() end_event.record() - di.synchronize() + if not is_cpu: + di.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 # compute number of warmup and repeat n_warmup = max(1, int(warmup / estimate_ms)) n_repeat = max(1, int(rep / estimate_ms)) - start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] - end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + if not is_cpu: + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + else: + start_event = [Event(is_cpu) for i in range(n_repeat)] + end_event = [Event(is_cpu) for i in range(n_repeat)] # Warm-up for _ in range(n_warmup): fn() diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index e527e5fc7ac3..a2c98e16ef5f 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -23,7 +23,11 @@ import triton import triton.language as tl +<<<<<<< HEAD DEVICE = triton.runtime.driver.active.get_active_torch_device() +======= +BLOCK_SIZE = 1024 +>>>>>>> 61ecff13b ([CPU] Support flexible active driver + update vector-add tutorial (#11)) @triton.jit @@ -59,7 +63,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. 
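# A minimal sketch (not part of this patch) of driving the new CPU timing path in
# triton.testing.do_bench directly; the tensor size is arbitrary and the measured
# op is plain torch.add rather than a Triton kernel:
import torch
import triton

x_cpu = torch.rand(1 << 22, device='cpu')
y_cpu = torch.rand_like(x_cpu)
out_cpu = torch.empty_like(x_cpu)
# is_cpu=True switches do_bench from CUDA events to perf_counter-based Events and
# allocates a large CPU-sized buffer to flush caches between runs.
ms_cpu = triton.testing.do_bench(lambda: torch.add(x_cpu, y_cpu, out=out_cpu), is_cpu=True)
print(f"torch.add on CPU: {ms_cpu:.3f} ms")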
# and (2) enqueue the above kernel with appropriate grid/block sizes: -def add(x: torch.Tensor, y: torch.Tensor): +def add(x: torch.Tensor, y: torch.Tensor, is_cpu): # We need to preallocate the output. output = torch.empty_like(x) assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE @@ -80,7 +84,6 @@ def add(x: torch.Tensor, y: torch.Tensor): # %% # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: - torch.manual_seed(0) size = 98432 x = torch.rand(size, device=DEVICE) @@ -110,21 +113,27 @@ def add(x: torch.Tensor, y: torch.Tensor): x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`. x_log=True, # x axis is logarithmic. line_arg='provider', # Argument name whose value corresponds to a different line in the plot. - line_vals=['triton', 'torch'], # Possible values for `line_arg`. - line_names=['Triton', 'Torch'], # Label name for the lines. - styles=[('blue', '-'), ('green', '-')], # Line styles. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. ylabel='GB/s', # Label name for the y-axis. - plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})', args={}, # Values for function arguments not in `x_names` and `y_name`. )) def benchmark(size, provider): x = torch.rand(size, device=DEVICE, dtype=torch.float32) y = torch.rand(size, device=DEVICE, dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] - if provider == 'torch': + if provider == 'torch-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) - if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles) + elif provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles) + elif provider == 'torch-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True) + elif provider == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) From 0db86511234a604a9500c26fedfb9ad0dabf10cd Mon Sep 17 00:00:00 2001 From: Gregory Shimansky Date: Fri, 7 Jun 2024 15:30:59 -0500 Subject: [PATCH 019/165] Added a simple workflow to run on self-hosted intel runner (#16) Signed-off-by: Gregory Shimansky --- .github/workflows/build-test.yml | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 .github/workflows/build-test.yml diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml new file mode 100644 index 000000000000..2e82cc17cad9 --- /dev/null +++ b/.github/workflows/build-test.yml @@ -0,0 +1,73 @@ +name: Build and test +run-name: ${{ inputs.run_name }} + +on: + workflow_dispatch: + +jobs: + pre-commit: + name: Pre-commit checks + runs-on: + - glados + - cpu + - intel + - x86 + steps: + - name: Print inputs + run: | + echo "${{ toJSON(github.event.inputs) }}" + echo INSTALL_IPEX=${{ env.INSTALL_IPEX }} + + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Python 3.11 + uses: actions/setup-python@v5 + with: + 
python-version: '3.11' + cache: pip + + - name: Run pre-commit checks + run: | + pip install --upgrade pre-commit + + # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed + python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true + # If first run of yapf worked and made changes reset the tree to the original state + git reset --hard + + python3 -m pre_commit run --show-diff-on-failure --color=always --all-files --verbose + + build-test: + name: Build and test + runs-on: + - glados + - cpu + - intel + - x86 + strategy: + matrix: + python: ['3.11'] + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install Python ${{ matrix.python }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + cache: pip + + - name: Install pip dependencies + run: | + python3 -m pip install --upgrade pip + python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit + - name: Install Triton + run: | + echo "PATH is '$PATH'" + cd python + python3 -m pip install --no-build-isolation -vvv '.[tests]' + - name: Run python unit tests + run: | + cd python/test/unit + python -m pytest -n 32 --device cpu python/test/unit/language/test_core.py -m cpu From bdc9462e2a2e2d0f939cd289960b49c5e55d67ce Mon Sep 17 00:00:00 2001 From: Gregory Shimansky Date: Sun, 9 Jun 2024 12:57:15 -0500 Subject: [PATCH 020/165] Fixed build and test workflow for intel self-hosted runner (#17) * Fixed yaml syntax Signed-off-by: Gregory Shimansky * Removed cpu label from run-on Signed-off-by: Gregory Shimansky * Added missing zlib-dev Signed-off-by: Gregory Shimansky * Added missing apt-get update Signed-off-by: Gregory Shimansky * Remove pip cache because on self-hosted runner it slows things down Signed-off-by: Gregory Shimansky * Corrected path to tests Signed-off-by: Gregory Shimansky * Added installation of torch==2.1.2 Signed-off-by: Gregory Shimansky --------- Signed-off-by: Gregory Shimansky --- .github/workflows/build-test.yml | 66 ++++++++++++++++---------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 2e82cc17cad9..7e1ac3e02c83 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -9,7 +9,6 @@ jobs: name: Pre-commit checks runs-on: - glados - - cpu - intel - x86 steps: @@ -25,7 +24,6 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.11' - cache: pip - name: Run pre-commit checks run: | @@ -38,36 +36,38 @@ jobs: python3 -m pre_commit run --show-diff-on-failure --color=always --all-files --verbose - build-test: - name: Build and test - runs-on: - - glados - - cpu - - intel - - x86 - strategy: - matrix: - python: ['3.11'] - steps: - - name: Checkout repository - uses: actions/checkout@v4 + build-test: + name: Build and test + runs-on: + - glados + - intel + - x86 + strategy: + matrix: + python: ['3.11'] + steps: + - name: Checkout repository + uses: actions/checkout@v4 - - name: Install Python ${{ matrix.python }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python }} - cache: pip + - name: Install Python ${{ matrix.python }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + + - name: Install pip and apt dependencies + run: | + python3 -m pip install --upgrade pip + python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit + sudo apt-get update + sudo apt-get install -y zlib1g-dev + pip install torch==2.1.2 - - name: Install 
pip dependencies - run: | - python3 -m pip install --upgrade pip - python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit - - name: Install Triton - run: | - echo "PATH is '$PATH'" - cd python - python3 -m pip install --no-build-isolation -vvv '.[tests]' - - name: Run python unit tests - run: | - cd python/test/unit - python -m pytest -n 32 --device cpu python/test/unit/language/test_core.py -m cpu + - name: Install Triton + run: | + echo "PATH is '$PATH'" + cd python + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Run python unit tests + run: | + python -m pytest -n 32 --device cpu python/test/unit/language/test_core.py -m cpu From 79224697a12b13bba7b0366cedfb8d92ec611283 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 10 Jun 2024 09:22:39 -0700 Subject: [PATCH 021/165] [CPU] Add an OpenMP-based CPU launcher (#15) * [CPU] Add OpenMP launcher * Address the comments * Fix induction variable type * Always use preallocated output buffer for CPU with torch.add --- python/triton/runtime/build.py | 2 +- python/tutorials/01-vector-add.py | 32 ++++++++----- third_party/cpu/backend/driver.py | 75 +++++++++++++++++++++++++++---- 3 files changed, 90 insertions(+), 19 deletions(-) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 19f74254dd6b..4568686be953 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -35,7 +35,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. if src.endswith(".cpp") or src.endswith(".cc"): - cc_cmd += ["-std=c++17"] + cc_cmd += ["-std=c++17", "-fopenmp"] ret = subprocess.check_call(cc_cmd) if ret == 0: return so diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index a2c98e16ef5f..2a660be1fd8d 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -23,12 +23,10 @@ import triton import triton.language as tl -<<<<<<< HEAD DEVICE = triton.runtime.driver.active.get_active_torch_device() -======= -BLOCK_SIZE = 1024 ->>>>>>> 61ecff13b ([CPU] Support flexible active driver + update vector-add tutorial (#11)) - +GPU_BLOCK_SIZE = 1024 +CPU_BLOCK_SIZE = 4096 +USE_GPU = True @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. @@ -76,7 +74,7 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): # - Each torch.tensor object is implicitly converted into a pointer to its first element. # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. # - Don't forget to pass meta-parameters as keywords arguments. - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if is_cpu else GPU_BLOCK_SIZE) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output @@ -119,21 +117,35 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): ylabel='GB/s', # Label name for the y-axis. plot_name= # Name for the plot. Used also as a file name for saving the plot. - f'vector-add-performance (BLOCK_SIZE={BLOCK_SIZE})', + f'vector-add-performance (CPU_BLOCK_SIZE={CPU_BLOCK_SIZE}, GPU_BLOCK_SIZE={GPU_BLOCK_SIZE})', args={}, # Values for function arguments not in `x_names` and `y_name`. 
)) def benchmark(size, provider): x = torch.rand(size, device=DEVICE, dtype=torch.float32) y = torch.rand(size, device=DEVICE, dtype=torch.float32) + + if DEVICE == 'cpu': + triton.runtime.driver.set_active_to_cpu() + else: + triton.runtime.driver.set_active_to_gpu() + quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, False), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles) elif provider == 'torch-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=True) + # Note that we preallocate the output buffer here to only measure the kernel performance + # without a large chunk of memory allocation. + output = torch.empty_like(x) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles, + is_cpu=True) + elif provider == 'triton-cpu-single': + output = torch.empty_like(x) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, True), quantiles=quantiles, is_cpu=True) + output = torch.empty_like(x) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 5783a0342dbd..1018f64d5b35 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -110,7 +110,6 @@ def __new__(cls): return cls.instance def __init__(self): - pass dirname = os.path.dirname(os.path.realpath(__file__)) mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") self.load_binary = mod.load_binary @@ -182,14 +181,39 @@ def format_of(ty): # generate glue code src = f""" +#include +#include #include -#include -#include +#include #include +#include +#include +#include +#include +#include #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include -#include + +inline bool getBoolEnv(const std::string &env) {{ + const char *s = std::getenv(env.c_str()); + std::string str(s ? s : ""); + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) {{ return std::tolower(c); }}); + return str == "on" || str == "true" || str == "1"; +}} + +inline std::optional getIntEnv(const std::string &env) {{ + const char *cstr = std::getenv(env.c_str()); + if (!cstr) + return std::nullopt; + + char *endptr; + long int result = std::strtol(cstr, &endptr, 10); + if (endptr == cstr) + assert(false && "invalid integer"); + return result; +}} using kernel_ptr_t = void(*)({kernel_fn_arg_types}); @@ -233,20 +257,55 @@ def format_of(ty): return ptr_info; }} -static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ - // TODO: add OMP pragmas to run in parallel +static std::unique_ptr get_all_grids(uint32_t gridX, uint32_t gridY, uint32_t gridZ) {{ + std::unique_ptr grids(new uint32_t[gridX * gridY * gridZ][3]); + // TODO: which order would be more effective for cache locality? 
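  // The flat index is z * gridY * gridX + y * gridX + x, so x varies fastest.
  // run_omp_kernels below parallelizes over this single linear range with
  // OpenMP; at run time it can be tuned via TRITON_CPU_SINGLE_CORE,
  // TRITON_CPU_MAX_THREADS and TRITON_CPU_OMP_DEBUG.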
for (uint32_t z = 0; z < gridZ; ++z) {{ for (uint32_t y = 0; y < gridY; ++y) {{ for (uint32_t x = 0; x < gridX; ++x) {{ - (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + grids[z * gridY * gridX + y * gridX + x][0] = x; + grids[z * gridY * gridX + y * gridX + x][1] = y; + grids[z * gridY * gridX + y * gridX + x][2] = z; }} }} }} + return grids; }} -static PyObject* launch(PyObject* self, PyObject* args) {{ +static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ + // TODO: Consider using omp collapse(3) clause for simplicity? + auto all_grids = get_all_grids(gridX, gridY, gridZ); + size_t N = gridX * gridY * gridZ; + + if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{ + if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) + printf("Single core launcher\\n"); + + for (size_t i = 0; i < N; ++i) {{ + const auto [x, y, z] = all_grids[i]; + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + }} + return; + }} + std::optional max_threads = getIntEnv("TRITON_CPU_MAX_THREADS"); + if (max_threads.has_value()) + max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads())); + else + max_threads = omp_get_max_threads(); + if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) + printf("N: %zu, max_threads: %d\\n", N, max_threads.value()); + + // For now, use the default chunk size, total iterations / max_threads. +#pragma omp parallel for schedule(static) num_threads(max_threads.value()) + for (size_t i = 0; i < N; ++i) {{ + const auto [x, y, z] = all_grids[i]; + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + }} +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ int gridX, gridY, gridZ; PyObject *launch_enter_hook = NULL; PyObject *launch_exit_hook = NULL; From 508dff5e26d93fe05c1147d97d9e43ce0f1bac4c Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 10 Jun 2024 14:44:21 -0500 Subject: [PATCH 022/165] Support generic reduction and scan cases. (#14) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 15 +- .../cpu/include/TritonToTritonCPU/Passes.h | 1 + .../cpu/include/TritonToTritonCPU/Passes.td | 14 + .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 3 +- .../TritonToTritonCPU/ConvertReductionOp.cpp | 112 ++++++-- .../lib/TritonToTritonCPU/ConvertScanOp.cpp | 156 +++++++++++ .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 1 + .../lib/TritonToTritonCPU/ReduceScanCommon.h | 244 ++++++++++++++++++ 8 files changed, 519 insertions(+), 27 deletions(-) create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 35c5da7a27b7..368343accaf6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2245,6 +2245,7 @@ def deserialize_fp8(np_data, in_dtype): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_max_returns_zero(device): # Simple test with a tl.max call that returns 0. 
The interpreter had a bug @@ -2271,6 +2272,7 @@ def get_reduced_dtype(dtype_str, op): return dtype_str +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ 'min', @@ -2388,9 +2390,6 @@ def kernel(X, Z, BLOCK: tl.constexpr): def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested - if is_cpu() and op in ('argmin', 'argmax'): - pytest.skip(f"Not yet implemented on CPU: {op}") - @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): @@ -2566,17 +2565,24 @@ def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): check_type_supported(dtype_str, device) if dtype_str == 'bfloat16': - if op == 'cummax': + if is_cuda() and op == 'cummax': pytest.skip("bfloat16 compare not supported before sm90") if op == 'linear_recurrence': pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + # bf16 vector cast is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/92471 + # TODO: Remove the change after the bug is fixed. + if is_cpu() and dtype_str == 'bfloat16': + shape = (min(shape[0], 128), min(shape[1], 128)) + # triton kernel @triton.jit def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): @@ -3240,6 +3246,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis, tmp_path: pathli np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) +@pytest.mark.cpu @pytest.mark.interpreter def test_generic_reduction(device): diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index f67c2de7e2ce..c7d072ab9175 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -25,6 +25,7 @@ std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); +std::unique_ptr> createConvertScanOp(); void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); void registerTritonToTritonCPUPipeline(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 60b9942d08cd..28ad258c38c0 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -100,4 +100,18 @@ def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> { + let summary = "Convert Triton ScanOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertScanOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git 
a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index d7974fe63079..fc22e12b867d 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -4,8 +4,9 @@ add_triton_library(TritonToTritonCPU ConvertElementwiseOps.cpp ConvertHistogramOp.cpp ConvertMemoryOps.cpp - ConvertReductionOp.cpp ConvertPtrOps.cpp + ConvertReductionOp.cpp + ConvertScanOp.cpp Pipeline.cpp TypeConverter.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index c69d8322f82d..d3a76d9a841b 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -1,22 +1,17 @@ +#include "ReduceScanCommon.h" #include "TypeConverter.h" #include "cpu/include/TritonToTritonCPU/Passes.h" -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include + namespace mlir { namespace triton { #define GEN_PASS_DEF_CONVERTREDUCTIONOP @@ -44,28 +39,91 @@ class ReductionConversionTarget : public ConversionTarget { } }; -struct ReduceOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ReduceOpConversion + : public ReduceScanOpConversionBase { + using ReduceScanOpConversionBase::ReduceScanOpConversionBase; LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = op.getContext(); - // Currently, only simple reductions with a single input argumet are - // supported. - // TODO: support generic case. + // More simple cases with a single input and a single combine + // operation can utilize target-specific reduction operations like + // horizaontal vector operations. We detect such cases here and map + // them to the vector::MultiDimReductionOp. 
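+    // The fast path only triggers for a single operand whose combine
+    // region consists of exactly one op mappable to a vector::CombiningKind;
+    // everything else falls through to the generic shuffle/extract based
+    // lowering inherited from ReduceScanOpConversionBase.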
+ if (succeeded(mapToMultiDimReductionOp(op, rewriter))) + return success(); + + return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter); + } + + SmallVector + lower1DInput(ValueRange inputs, ReduceOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + int64_t vecSize = cast(inputs[0].getType()).getShape()[0]; + SmallVector range(vecSize); + std::iota(range.begin(), range.end(), 0); + + ArrayRef dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector res = inputs; + for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) { + SmallVector shuffleIndices = range; + for (int64_t i = 0; i < stride; ++i) { + std::swap(shuffleIndices[i], shuffleIndices[i + stride]); + } + SmallVector shuffledInput; + for (auto [val, dummy] : llvm::zip(res, dummies)) { + shuffledInput.push_back(rewriter.create( + loc, val, dummy, shuffleIndices)); + } + + res = accumulate(shuffledInput, res, combineOp, rewriter); + } + + // The results are in the first element of each produced vector. + Value zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, res[i], zero); + } + return res; + } + + SmallVector + lowerLeadingDimension(ValueRange inputs, ReduceOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + auto shape = cast(inputs[0].getType()).getShape(); + SmallVector res; + for (int64_t idx = 0; idx < shape[0]; ++idx) { + SmallVector subInputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), subInputs.begin(), + [&](auto val) { + return rewriter.create(loc, val, idx); + }); + + res = accumulate(subInputs, res, combineOp, rewriter); + } + return res; + } + + LogicalResult + mapToMultiDimReductionOp(triton::ReduceOp op, + ConversionPatternRewriter &rewriter) const { if (op.getNumOperands() != 1 || op.getNumResults() != 1) return failure(); Value src = rewriter.getRemappedValue(op.getOperand(0)); - VectorType srcTy = dyn_cast(src.getType()); - assert(srcTy); + VectorType srcTy = cast(src.getType()); Block *block = op.getBody(); if (block->getNumArguments() != 2) return failure(); - Value itArg = block->getArgument(0); - Value accArg = block->getArgument(1); + Value accArg = block->getArgument(0); + Value itArg = block->getArgument(1); auto &blockOps = block->getOperations(); if (blockOps.size() != 2) @@ -155,7 +213,18 @@ struct ReduceOpConversion : public OpConversionPattern { elemTy, static_cast( (1UL << (elemTy.getIntOrFloatBitWidth() - 1)) - 1)); else if (kind == vector::CombiningKind::MINIMUMF || - kind == vector::CombiningKind::MINNUMF) { + kind == vector::CombiningKind::MAXIMUMF) { + if (elemTy.isF32()) + initVal = + rewriter.getF32FloatAttr(std::numeric_limits::quiet_NaN()); + else if (elemTy.isF64()) + initVal = + rewriter.getF64FloatAttr(std::numeric_limits::quiet_NaN()); + else + llvm_unreachable("Unsupported type for acc init value."); + } + + else if (kind == vector::CombiningKind::MINNUMF) { if (elemTy.isF32()) initVal = rewriter.getF32FloatAttr(std::numeric_limits::infinity()); @@ -164,8 +233,7 @@ struct ReduceOpConversion : public OpConversionPattern { rewriter.getF64FloatAttr(std::numeric_limits::infinity()); else llvm_unreachable("Unsupported type for acc init value."); - } else if (kind == vector::CombiningKind::MAXIMUMF || - kind == vector::CombiningKind::MAXNUMF) { + } else if (kind == vector::CombiningKind::MAXNUMF) { if 
(elemTy.isF32()) initVal = rewriter.getF32FloatAttr(-std::numeric_limits::infinity()); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp new file mode 100644 index 000000000000..5425b5dbf800 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp @@ -0,0 +1,156 @@ +#include "ReduceScanCommon.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTSCANOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ScanConversionTarget : public ConversionTarget { +public: + explicit ScanConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct ScanOpConversion + : public ReduceScanOpConversionBase { + using ReduceScanOpConversionBase::ReduceScanOpConversionBase; + + SmallVector + lower1DInput(ValueRange inputs, ScanOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + bool reverse = op.getReverse(); + int64_t vecSize = cast(inputs[0].getType()).getShape()[0]; + Type maskTy = VectorType::get(vecSize, rewriter.getI1Type()); + + ArrayRef dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector res = inputs; + for (int64_t stride = 1; stride < vecSize; stride *= 2) { + SmallVector shuffleIndices(vecSize, 0); + int64_t start = reverse ? vecSize - 1 - stride : stride; + int64_t end = reverse ? -1 : vecSize; + int64_t step = reverse ? -1 : 1; + for (int64_t i = start; i != end; i += step) { + shuffleIndices[i] = i - step * stride; + } + SmallVector shuffledInput; + for (auto [val, dummy] : llvm::zip(res, dummies)) { + shuffledInput.push_back(rewriter.create( + loc, val, dummy, shuffleIndices)); + } + + auto newRes = accumulate(res, shuffledInput, combineOp, rewriter); + + // Number of already computed elements is equal to the current + // stride. Mask them out using a constant mask. + SmallVector maskVals(vecSize, true); + if (reverse) { + std::fill(maskVals.rbegin(), maskVals.rbegin() + stride, false); + } else { + std::fill(maskVals.begin(), maskVals.begin() + stride, false); + } + Value mask = rewriter.create( + loc, maskTy, rewriter.getBoolVectorAttr(maskVals)); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = vector::selectPassthru(rewriter, mask, newRes[i], res[i]); + } + } + + return res; + } + + SmallVector + lowerLeadingDimension(ValueRange inputs, ScanOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + bool reverse = op.getReverse(); + auto shape = cast(inputs[0].getType()).getShape(); + SmallVector resTypes; + for (const auto &resTy : op.getResultTypes()) { + resTypes.push_back(VectorType::get( + shape, cast(resTy).getElementType())); + } + SmallVector res = makeEmptyResults(loc, resTypes, rewriter); + SmallVector acc; + int64_t start = reverse ? 
shape[0] - 1 : 0; + int64_t end = reverse ? -1 : shape[0]; + int64_t step = reverse ? -1 : 1; + for (int64_t idx = start; idx != end; idx += step) { + SmallVector subInputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), subInputs.begin(), + [&](auto val) { + return rewriter.create(loc, val, idx); + }); + + acc = accumulate(subInputs, acc, combineOp, rewriter); + + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, acc[i], res[i], idx); + } + } + return res; + } +}; + +struct ConvertScanOp : public triton::impl::ConvertScanOpBase { + using ConvertScanOpBase::ConvertScanOpBase; + + ConvertScanOp() : ConvertScanOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ScanConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertScanOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index 50d5814270d7..2b26cec34248 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -14,6 +14,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertDotOp()); pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); pm.addPass(mlir::triton::cpu::createConvertReductionOp()); + pm.addPass(mlir::triton::cpu::createConvertScanOp()); pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h new file mode 100644 index 000000000000..b2edc5e98b36 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h @@ -0,0 +1,244 @@ +#include "mlir/Transforms/DialectConversion.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include + +namespace mlir { +namespace triton { +namespace cpu { + +// Base class for converting scans and reductions. +// +// It provides accumulation function that clones operations from the +// original combine region and applies them on provided vectors. +// Also, it handles multi-diumensional cases reducing them to two +// possible options: lowering for a 1-D vector inputs and lowering +// the operation over the leading dimension. +// +// Specialized pattern should implement lower1DInput to handle +// trailing dimension case (commonly through shuffles + accumulate) +// and lowerLeadingDimension to handle the leading dimension case +// through accumulation of sub-vectors. 
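+//
+// For example (illustration only): reducing a <4x8xf32> value over
+// axis 1 extracts four <8xf32> rows and runs lower1DInput on each,
+// while reducing it over axis 0 runs lowerLeadingDimension, which
+// extracts the four rows and folds them together element-wise with
+// the cloned combine region.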
+template +struct ReduceScanOpConversionBase : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using typename OpConversionPattern::OpAdaptor; + + virtual SmallVector + lower1DInput(ValueRange inputs, OpT op, + ConversionPatternRewriter &rewriter) const = 0; + virtual SmallVector + lowerLeadingDimension(ValueRange inputs, OpT op, + ConversionPatternRewriter &rewriter) const = 0; + + LogicalResult + matchAndRewrite(OpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto rank = cast(op.getOperand(0).getType()).getRank(); + if (op.getAxis() == (rank - 1)) + return lowerTrailingDimension(op, rewriter); + + return lowerNonTrailingDimension(op, rewriter); + } + + // To handle the trailing dimension case, we extract all input vectors + // and process them through lower1DInput, then build the resulting + // vector using inserts. + LogicalResult + lowerTrailingDimension(OpT op, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + SmallVector inputTys(inputs.size()); + std::transform(inputs.begin(), inputs.end(), inputTys.begin(), + [](auto val) { return cast(val.getType()); }); + + // 1-D input case. + if (inputTys.front().getRank() == 1) { + auto res = lower1DInput(inputs, op, rewriter); + rewriter.replaceOp(op, res); + return success(); + } + + SmallVector res = + makeEmptyResults(loc, op.getResultTypes(), rewriter); + auto shape = inputTys[0].getShape(); + int64_t numElems = inputTys[0].getNumElements(); + auto strides = computeStrides(shape); + // Remove the last stride to produce sub-vector indices. + strides.pop_back(); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, indices); + }); + + auto resElems = lower1DInput(subInputs, op, rewriter); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, resElems[i], res[i], + indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + // In this case we either call lowerLeadingDimension to process the input + // or extract sub-vectors, call lowerLeadingDimension, and then reconstruct + // the result. + LogicalResult + lowerNonTrailingDimension(OpT op, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + uint32_t axis = op.getAxis(); + if (axis == 0) { + rewriter.replaceOp(op, lowerLeadingDimension(inputs, op, rewriter)); + return success(); + } + + SmallVector res = + makeEmptyResults(loc, op.getResultTypes(), rewriter); + auto vecTy = cast(inputs[0].getType()); + auto shape = vecTy.getShape(); + auto strides = computeStrides(shape); + // Remove trailing elems to build indices of required rank. 
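+    // E.g. for shape [2, 4, 8] and axis == 1 the strides [32, 8, 1]
+    // are trimmed to [32], so the loop below visits idx = 0 and 32,
+    // extracting the two <4x8> sub-vectors that lowerLeadingDimension
+    // then processes over their leading dimension.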
+ strides.erase(strides.begin() + axis, strides.end()); + int64_t numElems = vecTy.getNumElements(); + int64_t step = strides.back(); + for (int64_t idx = 0; idx < numElems; idx += step) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, indices); + }); + auto resVecs = lowerLeadingDimension(subInputs, op, rewriter); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = + rewriter.create(loc, resVecs[i], res[i], indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + // Accumulate inputs and existing accumulators into a new accumaltors + // applying operations from the combine region. + SmallVector accumulate(ValueRange inputs, ValueRange acc, + Region &combineOp, + ConversionPatternRewriter &rewriter) const { + if (acc.empty()) + return inputs; + + auto shape = cast(inputs[0].getType()).getShape(); + auto &block = combineOp.getBlocks().front(); + IRMapping map; + // Map block arguments to the current inputs and accumulators. + for (unsigned i = 0; i < acc.size(); ++i) { + map.map(block.getArgument(i), acc[i]); + map.map(block.getArgument(acc.size() + i), inputs[i]); + } + for (auto &op : block.getOperations()) { + // Returned values are a new accumulator. + if (isa(op)) { + SmallVector res; + for (auto operand : op.getOperands()) { + res.push_back(map.lookup(operand)); + } + return res; + } + + // Clone operation mapping its inputs and building vector + // result types using the input shape. + OperationState newState(op.getLoc(), op.getName()); + for (auto operand : op.getOperands()) { + newState.operands.push_back( + lookupMappedValue(map, operand, shape, rewriter)); + } + for (auto ty : op.getResultTypes()) { + newState.types.push_back(VectorType::get(shape, ty)); + } + newState.attributes = op.getAttrs(); + auto newOp = rewriter.create(newState); + + // Add new values to the map. + for (auto [oldVal, newVal] : + llvm::zip(op.getResults(), newOp->getResults())) { + map.map(oldVal, newVal); + } + } + llvm_unreachable("No return op found in scan/reduce region"); + } + + Value lookupMappedValue(IRMapping &localMap, Value val, + ArrayRef shape, + ConversionPatternRewriter &rewriter) const { + + Value res = localMap.lookupOrNull(val); + if (!res) { + // If value is not found then it's an invariant defined in the outer + // region. We check if it has been already translated and add a splat + // operation if it hasn't. + res = invariantsMap.lookupOrNull(val); + if (!res) { + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfterValue(val); + res = rewriter.create( + val.getLoc(), VectorType::get(shape, val.getType()), val); + invariantsMap.map(val, res); + rewriter.restoreInsertionPoint(ip); + } + } + return res; + } + + SmallVector + makeEmptyResults(Location loc, TypeRange resTypes, + ConversionPatternRewriter &rewriter) const { + // Initialize results to zero values. + SmallVector res; + for (auto ty : resTypes) { + res.push_back(rewriter.create( + loc, rewriter.getZeroAttr(getTypeConverter()->convertType(ty)))); + } + return res; + } + + // Dummy vectors are required for shuffles that cannot work on a single + // vector. 
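+  // The dummies are just zero-initialized constants of a matching element
+  // type; the shuffle indices produced by the patterns only ever reference
+  // lanes of the first operand, so the dummy values never reach the result.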
+ ArrayRef + createShuffleDummies(Location loc, ValueRange inputs, + ConversionPatternRewriter &rewriter) const { + if (shuffleDummies.empty()) { + for (auto val : inputs) { + auto ty = cast(val.getType()); + shuffleDummies.push_back(rewriter.create( + loc, rewriter.getZeroAttr(ty.cloneWith(1, ty.getElementType())))); + } + } + return shuffleDummies; + } + +private: + mutable IRMapping invariantsMap; + mutable SmallVector shuffleDummies; +}; + +} // namespace cpu +} // namespace triton +} // namespace mlir From e93ef5a1f18d0dfce8ca8b49e5dc05eb74b89df6 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Tue, 11 Jun 2024 08:40:15 -0700 Subject: [PATCH 023/165] [CPU] Dump human-readable asm code in TRITON_CACHE_DIR (#19) * [CPU] Dump human-readable asm code in TRITON_CACHE_DIR * Don't touch the main compiler.py --- third_party/cpu/backend/compiler.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index d48fbf3a96bf..c3f11334750a 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -1,7 +1,6 @@ import functools import hashlib import os -import re from dataclasses import dataclass from typing import Any, Tuple @@ -141,8 +140,12 @@ def make_llir(src, metadata, options): @staticmethod def make_bc(src, metadata, options): if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": - print("********** Module ASM **********") - print(llvm.translate_to_host_asm(src, options.enable_fp_fusion)) + from triton.runtime.cache import get_cache_manager + + asm = llvm.translate_to_host_asm(src, options.enable_fp_fusion) + fn_cache_manager = get_cache_manager(metadata['hash']) + fn_cache_manager.put(asm, f"{metadata['name']}.asm") + ret = llvm.translate_to_bc(src) return ret From ff40f16891035d4682d46ad63cbd24d6cd7a2f0b Mon Sep 17 00:00:00 2001 From: Gregory Shimansky Date: Tue, 11 Jun 2024 16:06:54 -0500 Subject: [PATCH 024/165] Added g++ installation after switching to ubuntu-22.04 (#21) Signed-off-by: Gregory Shimansky --- .github/workflows/build-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 7e1ac3e02c83..f9f87e6ba1f7 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -59,7 +59,7 @@ jobs: python3 -m pip install --upgrade pip python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit sudo apt-get update - sudo apt-get install -y zlib1g-dev + sudo apt-get install -y zlib1g-dev g++ pip install torch==2.1.2 - name: Install Triton From dbc68ed40fefb1cbb9d57a857d53ebcc20a11a93 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 11 Jun 2024 16:09:00 -0500 Subject: [PATCH 025/165] Support atomic ops for CPU. 
(#20) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 6 + .../cpu/include/TritonCPUToLLVM/Passes.h | 1 + .../cpu/include/TritonCPUToLLVM/Passes.td | 11 + .../cpu/include/TritonToTritonCPU/Passes.h | 1 + .../cpu/include/TritonToTritonCPU/Passes.td | 14 ++ .../lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp | 154 +++++++++++++ .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 1 + .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp | 1 + .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 1 + .../TritonToTritonCPU/ConvertAtomicOps.cpp | 218 ++++++++++++++++++ .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 1 + 11 files changed, 409 insertions(+) create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 368343accaf6..9cade5c3e89e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1518,6 +1518,7 @@ def kernel(X, Y, Z): # --------------- # test atomics # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "op, dtype_x_str, mode, sem", @@ -1599,6 +1600,7 @@ def kernel(X, Z): assert f"atom.global.gpu.{sem_str}" in h.asm["ptx"] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_atomic_rmw_predicate(num_ctas, device): @@ -1614,6 +1616,7 @@ def kernel(X): assert x.item() == 63 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("shape, axis, num_ctas, dtype_x_str, check_return_val", [(shape, axis, num_ctas, dtype_x_str, check_return_val) @@ -1682,6 +1685,7 @@ def torch_to_triton_dtype(t): np.testing.assert_equal(old_ref, to_numpy(old_tri)) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_tensor_atomic_rmw_block(num_ctas, device): @@ -1701,6 +1705,7 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): assert torch.min(x).item() == 0.0 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -1742,6 +1747,7 @@ def serialized_add(data, Lock, SEM: tl.constexpr): assert f"atom.global.{sem_str}" in h.asm["ptx"] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) @pytest.mark.parametrize("num_ctas", num_ctas_list) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index a1fbce2e4892..7d739f1c32fe 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -24,6 +24,7 @@ std::unique_ptr> createFuncOpToLLVMPass(); std::unique_ptr> createMemoryOpToLLVMPass(); std::unique_ptr> createGetProgramIdOpToLLVMPass(); std::unique_ptr> createLowerMultiReductionPass(); +std::unique_ptr> createAtomicOpsToLLVMPass(); void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); void registerTritonCPUToLLVMPipeline(); diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index 2abe88338dcf..0759ddbf7925 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -54,4 +54,15 @@ def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton "mlir::triton::TritonDialect"]; } +def 
AtomicOpsToLLVM : Pass<"triton-cpu-atomic-ops-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton atomic operations to LLVM."; + let description = [{ + }]; + let constructor = "mlir::triton::cpu::createAtomicOpsToLLVMPass()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; +} + #endif diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index c7d072ab9175..b5107e5e78a3 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -26,6 +26,7 @@ std::unique_ptr> createConvertControlFlowOps(); std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); std::unique_ptr> createConvertScanOp(); +std::unique_ptr> createConvertAtomicOps(); void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); void registerTritonToTritonCPUPipeline(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 28ad258c38c0..5dd3bf903440 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -114,4 +114,18 @@ def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> { "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton atomic operations."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertAtomicOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp new file mode 100644 index 000000000000..9a2c183e1c4c --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/AtomicOpsToLLVM.cpp @@ -0,0 +1,154 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_ATOMICOPSTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +LLVM::AtomicOrdering getOrdering(MemSemantic sem) { + switch (sem) { + case MemSemantic::RELAXED: + return LLVM::AtomicOrdering::monotonic; + case MemSemantic::ACQUIRE: + return LLVM::AtomicOrdering::acquire; + case MemSemantic::RELEASE: + return LLVM::AtomicOrdering::release; + case MemSemantic::ACQUIRE_RELEASE: + return LLVM::AtomicOrdering::acq_rel; + default: + llvm_unreachable("Unexpected atomic mem semantic"); + } +} + +// TODO: use enums to access struct fields. 
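+// With these patterns a scalar tt.atomic_rmw becomes a single
+// llvm.atomicrmw with the ordering mapped above, and tt.atomic_cas
+// becomes an llvm.cmpxchg whose first struct field (the old value)
+// is extracted as the op result.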
+struct AtomicRMWOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto opKind = getAtomicBinOp(op.getAtomicRmwOp(), op.getType()); + auto ptr = rewriter.getRemappedValue(op.getPtr()); + auto val = rewriter.getRemappedValue(op.getVal()); + auto ordering = getOrdering(op.getSem()); + rewriter.replaceOpWithNewOp(op, opKind, ptr, val, + ordering); + return success(); + } + + LLVM::AtomicBinOp getAtomicBinOp(RMWOp op, Type type) const { + switch (op) { + case RMWOp::AND: + return LLVM::AtomicBinOp::_and; + case RMWOp::OR: + return LLVM::AtomicBinOp::_or; + case RMWOp::XOR: + return LLVM::AtomicBinOp::_xor; + case RMWOp::ADD: + return LLVM::AtomicBinOp::add; + case RMWOp::FADD: + return LLVM::AtomicBinOp::fadd; + case RMWOp::MAX: + return type.isIntOrIndex() ? LLVM::AtomicBinOp::max + : LLVM::AtomicBinOp::fmax; + case RMWOp::MIN: + return type.isIntOrIndex() ? LLVM::AtomicBinOp::min + : LLVM::AtomicBinOp::fmin; + case RMWOp::UMAX: + return LLVM::AtomicBinOp::umax; + case RMWOp::UMIN: + return LLVM::AtomicBinOp::umin; + case RMWOp::XCHG: + return LLVM::AtomicBinOp::xchg; + default: + llvm_unreachable("Unexpected atomic op"); + } + } +}; + +struct AtomicCASOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ptr = rewriter.getRemappedValue(op.getPtr()); + auto cmp = rewriter.getRemappedValue(op.getCmp()); + auto val = rewriter.getRemappedValue(op.getVal()); + auto ordering = getOrdering(op.getSem()); + auto failureOrdering = ordering != LLVM::AtomicOrdering::monotonic + ? 
LLVM::AtomicOrdering::acquire + : ordering; + Value cmpXchg = rewriter.create( + loc, ptr, cmp, val, ordering, failureOrdering); + Value oldVal = rewriter.create(loc, cmpXchg, 0); + rewriter.replaceOp(op, oldVal); + return success(); + } +}; + +struct AtomicOpsToLLVM + : public triton::impl::AtomicOpsToLLVMBase { + using AtomicOpsToLLVMBase::AtomicOpsToLLVMBase; + + AtomicOpsToLLVM() : AtomicOpsToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createAtomicOpsToLLVMPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index 0cf83bc03b06..9e5f71f8d4e5 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonCPUToLLVM + AtomicOpsToLLVM.cpp FuncOpToLLVM.cpp GetProgramIdOpToLLVM.cpp LowerMultiReduction.cpp diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp index 914f56e668f8..0263a1e65214 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp @@ -11,6 +11,7 @@ void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); + pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass()); // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index fc22e12b867d..636ea039e718 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonToTritonCPU + ConvertAtomicOps.cpp ConvertControlFlowOps.cpp ConvertDotOp.cpp ConvertElementwiseOps.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp new file mode 100644 index 000000000000..61d3ac65e2fc --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp @@ -0,0 +1,218 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTATOMICOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace 
mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class AtomicConversionTarget : public ConversionTarget { +public: + explicit AtomicConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addDynamicallyLegalOp( + [&](triton::AtomicRMWOp op) -> std::optional { + return converter.isLegal(op) && !op.getMask(); + }); + addDynamicallyLegalOp( + [&](triton::AtomicCASOp op) -> std::optional { + return converter.isLegal(op); + }); + } +}; + +struct AtomicRMWOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto mask = + op.getMask() ? rewriter.getRemappedValue(op.getMask()) : nullptr; + arith::ConstantOp maskCst = mask ? getConstMaskDef(mask) : nullptr; + auto rmwOp = op.getAtomicRmwOp(); + auto ptrs = rewriter.getRemappedValue(op.getPtr()); + auto vals = rewriter.getRemappedValue(op.getVal()); + auto sem = op.getSem(); + auto scope = op.getScope(); + + if (mask && !isa(mask.getType())) { + auto res = lowerScalarMaskToCF(loc, rmwOp, ptrs, vals, mask, sem, scope, + rewriter); + rewriter.replaceOp(op, res); + return success(); + } + + auto ptrTy = cast(op.getPtr().getType()).getElementType(); + auto vecTy = cast(vals.getType()); + auto strides = computeStrides(vecTy.getShape()); + auto res = + rewriter.create(loc, rewriter.getZeroAttr(vecTy)); + int64_t numElems = vecTy.getNumElements(); + for (int64_t idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + Value resElem; + + if (mask && !maskCst) { + // Non-const mask values are lowered to CF. + Value maskVal = rewriter.create(loc, mask, indices); + resElem = lowerScalarMaskToCF(loc, rmwOp, ptr, val, maskVal, sem, scope, + rewriter); + } else if (!mask || + (maskCst && cast(maskCst.getValue()) + .getValues()[idx])) { + // Const true mask case. + resElem = rewriter.create( + loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); + } + + // Elements with const false mask are skipped. + if (resElem) { + rewriter.create(loc, resElem, res, indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + Value lowerScalarMaskToCF(Location loc, RMWOp rmwOp, Value ptr, Value val, + Value mask, MemSemantic sem, MemSyncScope scope, + ConversionPatternRewriter &rewriter) const { + // Check for constant mask. 
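+    // If the mask folds to a constant we can skip the control flow:
+    // a false mask just yields a zero placeholder, a true mask
+    // degenerates into an unconditional atomic RMW.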
+ if (auto maskDef = mask.getDefiningOp()) { + auto maskVal = cast(maskDef.getValue()); + if (maskVal.getValue().isZero()) { + return rewriter.create( + loc, rewriter.getZeroAttr(val.getType())); + } else { + return rewriter.create( + loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); + } + } + + Block *headerBlock = rewriter.getBlock(); + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(val.getType())); + Block *condBlock = + rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(condBlock); + Value resVal = rewriter.create( + loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); + Block *footerBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + Value res = footerBlock->addArgument(resVal.getType(), resVal.getLoc()); + rewriter.setInsertionPointToEnd(headerBlock); + rewriter.create(loc, mask, condBlock, footerBlock, zero); + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, footerBlock, resVal); + rewriter.setInsertionPointToStart(footerBlock); + + return res; + } + + arith::ConstantOp getConstMaskDef(Value mask) const { + while (auto cast = mask.getDefiningOp()) + mask = cast.getOperand(0); + return mask.getDefiningOp(); + } +}; + +struct AtomicCASOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ptrs = rewriter.getRemappedValue(op.getPtr()); + auto cmpVals = rewriter.getRemappedValue(op.getCmp()); + auto vals = rewriter.getRemappedValue(op.getVal()); + auto sem = op.getSem(); + auto scope = op.getScope(); + auto ptrTy = cast(op.getPtr().getType()).getElementType(); + auto vecTy = cast(vals.getType()); + auto strides = computeStrides(vecTy.getShape()); + auto res = + rewriter.create(loc, rewriter.getZeroAttr(vecTy)); + int64_t numElems = vecTy.getNumElements(); + for (int64_t idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + Value ptr = rewriter.create(loc, ptrs, indices); + ptr = rewriter.create(loc, ptrTy, ptr); + Value val = rewriter.create(loc, vals, indices); + Value cmpVal = rewriter.create(loc, cmpVals, indices); + Value resElem = rewriter.create( + loc, val.getType(), ptr, cmpVal, val, sem, scope); + rewriter.create(loc, resElem, res, indices); + } + + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertAtomicOps + : public triton::impl::ConvertAtomicOpsBase { + using ConvertAtomicOpsBase::ConvertAtomicOpsBase; + + ConvertAtomicOps() : ConvertAtomicOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + AtomicConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertAtomicOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index 2b26cec34248..ec7c62f72f52 100644 --- 
a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -16,6 +16,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertReductionOp()); pm.addPass(mlir::triton::cpu::createConvertScanOp()); pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); + pm.addPass(mlir::triton::cpu::createConvertAtomicOps()); // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); } From 2af366d7dde72b611e359416a3d3ae112cef3adf Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Thu, 13 Jun 2024 21:33:46 -0400 Subject: [PATCH 026/165] [TUTORIAL] Add unmasked matrix multiply example to triton-cpu (#23) * add un-masked tiled matrix-multiplication for triton-cpu * clean and add comment * move test under tutorials --- .../tutorials/03-matrix-multiplication-cpu.py | 394 ++++++++++++++++++ 1 file changed, 394 insertions(+) create mode 100644 python/tutorials/03-matrix-multiplication-cpu.py diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py new file mode 100644 index 000000000000..0cc90a474052 --- /dev/null +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -0,0 +1,394 @@ +""" +Matrix Multiplication +===================== +In this tutorial, you will write a very short high-performance FP32 matrix multiplication kernel. + +You will specifically learn about: + +* Block-level matrix multiplications. + +* Multi-dimensional pointer arithmetic. + +* Program re-ordering for improved L2 cache hit rate. + +* Automatic performance tuning. + +""" + +# %% +# Motivations +# ----------- +# +# Matrix multiplications are a key building block of most modern high-performance computing systems. +# They are notoriously hard to optimize, hence their implementation is generally done by +# hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +# Unfortunately, these libraries are often proprietary and cannot be easily customized +# to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). +# In this tutorial, you will learn how to implement efficient matrix multiplications by +# yourself with Triton, in a way that is easy to customize and extend. +# +# Roughly speaking, the kernel that we will write will implement the following blocked +# algorithm to multiply a (M, K) by a (K, N) matrix: +# +# .. code-block:: python +# +# # Do in parallel +# for m in range(0, M, BLOCK_SIZE_M): +# # Do in parallel +# for n in range(0, N, BLOCK_SIZE_N): +# acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) +# for k in range(0, K, BLOCK_SIZE_K): +# a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] +# b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] +# acc += dot(a, b) +# C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc +# +# where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. + +# %% +# Compute Kernel +# -------------- +# +# The above algorithm is, actually, fairly straightforward to implement in Triton. +# The main difficulty comes from the computation of the memory locations at which blocks +# of :code:`A` and :code:`B` must be read in the inner loop. For that, we need +# multi-dimensional pointer arithmetic. +# +# Pointer Arithmetic +# ~~~~~~~~~~~~~~~~~~~ +# +# For a row-major 2D tensor :code:`X`, the memory location of :code:`X[i, j]` is given +# by :code:`&X[i, j] = X + i*stride_xi + j*stride_xj`. 
+# Therefore, blocks of pointers for :code:`A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and +# :code:`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: +# +# .. code-block:: python +# +# &A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1); +# &B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1); +# +# Which means that pointers for blocks of A and B can be initialized (i.e., :code:`k=0`) in Triton as the following +# code. Also note that we need an extra modulo to handle the case where :code:`M` is not a multiple of +# :code:`BLOCK_SIZE_M` or :code:`N` is not a multiple of :code:`BLOCK_SIZE_N`, in which case we can pad the data with +# some useless values, which will not contribute to the results. For the :code:`K` dimension, we will handle that later +# using masking load semantics. +# +# .. code-block:: python +# +# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M +# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# a_ptrs = a_ptr + (offs_am[:, None]*stride_am + offs_k [None, :]*stride_ak) +# b_ptrs = b_ptr + (offs_k [:, None]*stride_bk + offs_bn[None, :]*stride_bn) +# +# And then updated in the inner loop as follows: +# +# .. code-block:: python +# +# a_ptrs += BLOCK_SIZE_K * stride_ak; +# b_ptrs += BLOCK_SIZE_K * stride_bk; +# +# +# L2 Cache Optimizations +# ~~~~~~~~~~~~~~~~~~~~~~ +# +# As mentioned above, each program instance computes a :code:`[BLOCK_SIZE_M, BLOCK_SIZE_N]` +# block of :code:`C`. +# It is important to remember that the order in which these blocks are computed does +# matter, since it affects the L2 cache hit rate of our program, and unfortunately, a +# simple row-major ordering +# +# .. code-block:: Python +# +# pid = triton.program_id(0); +# grid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M; +# grid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N; +# pid_m = pid / grid_n; +# pid_n = pid % grid_n; +# +# is just not going to cut it. +# +# One possible solution is to launch blocks in an order that promotes data reuse. +# This can be done by 'super-grouping' blocks in groups of :code:`GROUP_M` rows before +# switching to the next column: +# +# .. code-block:: python +# +# # Program ID +# pid = tl.program_id(axis=0) +# # Number of program ids along the M axis +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# # Number of programs ids along the N axis +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# # Number of programs in group +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# # Id of the group this program is in +# group_id = pid // num_pid_in_group +# # Row-id of the first program in the group +# first_pid_m = group_id * GROUP_SIZE_M +# # If `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# # *Within groups*, programs are ordered in a column-major order +# # Row-id of the program in the *launch grid* +# pid_m = first_pid_m + (pid % group_size_m) +# # Col-id of the program in the *launch grid* +# pid_n = (pid % num_pid_in_group) // group_size_m +# +# For example, in the following matmul where each matrix is 9 blocks by 9 blocks, +# we can see that if we compute the output in row-major ordering, we need to load 90 +# blocks into SRAM to compute the first 9 output blocks, but if we do it in grouped +# ordering, we only need to load 54 blocks. 
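+# (Where those numbers come from: the first 9 row-major outputs need one
+# row of A, 9 blocks, plus all 9 columns of B, 81 blocks, i.e. 90 loads;
+# a 3x3 group needs 3 rows of A plus 3 columns of B, i.e. 27 + 27 = 54.)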
+# +# .. image:: grouped_vs_row_major_ordering.png +# +# In practice, this can improve the performance of our matrix multiplication kernel by +# more than 10\% on some hardware architecture (e.g., 220 to 245 TFLOPS on A100). +# + +# %% +# Final Result +# ------------ + +import torch + +import triton +import triton.language as tl + + +BLOCK_SIZE_M = 32 +BLOCK_SIZE_N = 32 +BLOCK_SIZE_K = 32 +GROUP_SIZE_M = 8 +USE_GPU = True + +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to matrix C's type after the loop, if C has lower precision type (for example, float16 and bfloat16). + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + + #TODO: Currently masked load is not supported yet. + #a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + #b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Convert the accumulator to the output matrix C's type if needed. + c = accumulator + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + + #TODO: Currently masked load is not supported yet. + #c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + #tl.store(c_ptrs, c, mask=c_mask) + tl.store(c_ptrs, c) + + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + # 1D launch kernel where each block gets its own program. + grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + ) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation. + +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + + +a = torch.randn((512, 512), device='cpu', dtype=torch.float32) +b = torch.randn((512, 512), device='cpu', dtype=torch.float32) +triton_output = matmul(a, b) +torch_output = torch.matmul(a, b) +print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}") +print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}") +rtol = 0 +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}') + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. 
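+#
+# The plot below reports throughput in GFLOPS. Each of the :math:`M \times N` output elements is a
+# dot product of length :math:`K` (roughly :math:`K` multiplies and :math:`K` adds), so every provider
+# is credited with
+#
+# .. math::
+#     \text{FLOPs} \approx 2 \cdot M \cdot N \cdot K
+#
+# divided by the measured time; this is what the :code:`perf` lambda at the end of the benchmark
+# function below computes.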
+
+LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu']
+LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU']
+LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')]
+
+if USE_GPU and triton.runtime.driver.get_active_gpus():
+    triton.runtime.driver.set_active_to_gpu()
+    a = a.to('cuda')
+    b = b.to('cuda')
+    triton_output = matmul(a, b)
+    torch_output = torch.matmul(a, b)
+    print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}")
+    print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}")
+    rtol = 0
+    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
+        print("✅ TritonGPU and TorchGPU match")
+    else:
+        print("❌ TritonGPU and TorchGPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}')
+
+    LINE_VALS += ['triton-gpu', 'torch-gpu']
+    LINE_NAMES += ['TritonGPU', 'TorchGPU']
+    LINE_STYLES += [('yellow', '-'), ('red', '-')]
+
+
+# %%
+# Seems like we're good to go!
+
+# %%
+# Benchmark
+# ---------
+#
+# We can now benchmark our custom op on matrices of increasing sizes to get a sense of how it does relative to PyTorch.
+# To make things easier, Triton has a set of built-in utilities that allow us to concisely plot the performance of our custom ops
+# for different problem sizes.
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["M", "N", "K"],  # Argument names to use as an x-axis for the plot
+        x_vals=[128 * i for i in range(2, 21)],  # Different possible values for `x_name`
+        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
+        line_vals=LINE_VALS,  # Possible values for `line_arg`.
+        line_names=LINE_NAMES,  # Label name for the lines.
+        styles=LINE_STYLES,  # Line styles.
+        ylabel='GFLOPS',  # Label name for the y-axis.
+        plot_name=
+        # Name for the plot. Used also as a file name for saving the plot.
+        f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})',
+        args={},  # Values for function arguments not in `x_names` and `y_name`.
+    ))
+
+def benchmark(M, N, K, provider):
+    import os
+
+    device = 'cpu' if 'cpu' in provider else 'cuda'
+    a = torch.randn((M, K), device=device, dtype=torch.float32)
+    b = torch.randn((K, N), device=device, dtype=torch.float32)
+
+    if device == 'cpu':
+        triton.runtime.driver.set_active_to_cpu()
+        if 'single' in provider:
+            os.environ['TRITON_CPU_SINGLE_CORE'] = '1'
+        else:
+            os.unsetenv('TRITON_CPU_SINGLE_CORE')
+    else:
+        triton.runtime.driver.set_active_to_gpu()
+
+    quantiles = [0.5, 0.2, 0.8]
+    if provider == 'torch-gpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+    elif provider == 'triton-gpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+    elif provider == 'torch-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
+    elif provider == 'triton-cpu-single':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+    elif provider == 'triton-cpu':
+        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
+    perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3)
+    return perf(ms), perf(max_ms), perf(min_ms)
+
+
+# %%
+# We can now run the decorated function above. 
Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) From b8334a48124958a11f05ac04494b6d6a66a49061 Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Fri, 14 Jun 2024 21:33:40 -0400 Subject: [PATCH 027/165] Update matrix-multiplication-cpu tutorial, use preallocated output buffer for CPU. (#24) --- .../tutorials/03-matrix-multiplication-cpu.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index 0cc90a474052..c20a36aab10e 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -249,15 +249,19 @@ def matmul_kernel( # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. -def matmul(a, b): +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.is_contiguous(), "Matrix A must be contiguous" M, K = a.shape K, N = b.shape + #TODO: Currently masked load is not supported yet. assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=a.dtype) + if c is None: + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + else: + assert c.shape == (M, N), "Incompatible dimensions" # 1D launch kernel where each block gets its own program. grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), ) matmul_kernel[grid]( @@ -284,10 +288,9 @@ def matmul(a, b): triton.runtime.driver.set_active_to_cpu() - a = torch.randn((512, 512), device='cpu', dtype=torch.float32) b = torch.randn((512, 512), device='cpu', dtype=torch.float32) -triton_output = matmul(a, b) +triton_output = matmul(a, b, None) torch_output = torch.matmul(a, b) print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}") print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}") @@ -315,7 +318,7 @@ def matmul(a, b): triton.runtime.driver.set_active_to_gpu() a = a.to('cuda') b = b.to('cuda') - triton_output = matmul(a, b) + triton_output = matmul(a, b, None) torch_output = torch.matmul(a, b) print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}") print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}") @@ -377,13 +380,16 @@ def benchmark(M, N, K, provider): if provider == 'torch-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles) elif provider == 'torch-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + ms, min_ms, max_ms = 
triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles) + c = torch.empty((M, N), device=a.device, dtype=a.dtype) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) From 68c9780a8d66bbc413fb97d2dfdc4520ec39f4a8 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 17 Jun 2024 19:00:14 -0500 Subject: [PATCH 028/165] Fixes for x86 CI workflow (#26) * Fix RelWithDebInfo build. Signed-off-by: Ilya Enkovich * Skip fp8 cast tests on CPU. Signed-off-by: Ilya Enkovich * Fix segfault. Signed-off-by: Ilya Enkovich * [BACKEND] Update LLVM version to https://github.com/llvm/llvm-project/commit/765206e050453018e861637a08a4520f29238074 (#4059) * Add -s option to pytest run. Signed-off-by: Ilya Enkovich * Add a workaround for LLVM bug causing test failure on Skylake CPU. Signed-off-by: Ilya Enkovich * Add a workaround for LLVM fpext bug causing test failure on Skylake CPU. Signed-off-by: Ilya Enkovich * Fix formatting. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich Co-authored-by: Pablo Zimmermann --- .github/workflows/build-test.yml | 2 +- python/src/ir.cc | 2 ++ python/test/unit/language/test_core.py | 15 +++++++++++ .../tutorials/03-matrix-multiplication-cpu.py | 27 +++++++++---------- .../lib/TritonToTritonCPU/ReduceScanCommon.h | 4 ++- 5 files changed, 34 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index f9f87e6ba1f7..8c9bcca7cf28 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -70,4 +70,4 @@ jobs: - name: Run python unit tests run: | - python -m pytest -n 32 --device cpu python/test/unit/language/test_core.py -m cpu + python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu diff --git a/python/src/ir.cc b/python/src/ir.cc index caaefdfdec21..dd70dc13b511 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1845,6 +1845,8 @@ void init_triton_ir(py::module &&m) { llvm::SmallVector debugTypes = parseCommaSeparatedValues(debugOnly, storage); ::llvm::DebugFlag = true; + // For release build setCurrentDebugTypes is a macro, so avoid + // namespace prefix using namespace llvm; setCurrentDebugTypes(debugTypes.data(), debugTypes.size()); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9cade5c3e89e..1229a212c0ad 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1848,6 +1848,15 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"): pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') + if is_cpu() and (dtype_x in torch_float8_dtypes or dtype_z in torch_float8_dtypes): + pytest.skip(f'test_cast{(dtype_x, dtype_z)} is not supported on CPU.') + + # fptrunc fp32->fp16 is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/95274 + # TODO: remove the change after the bug is fixed. 
+ if is_cpu() and dtype_x == "float32" and dtype_z == "float16": + size = 512 + # bf16 vector cast is broken in LLVM for large vectors: # https://github.com/llvm/llvm-project/issues/92471 # TODO: Remove the change after the bug is fixed. @@ -2396,6 +2405,12 @@ def kernel(X, Z, BLOCK: tl.constexpr): def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested + # fpext fp16->fp32 is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/95278 + # TODO: remove the change after the bug is fixed. + if is_cpu() and dtype_str == "float16": + shape = (min(shape[0], 512), min(shape[1], 512)) + @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index c20a36aab10e..e96d04614661 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -154,13 +154,13 @@ import triton import triton.language as tl - BLOCK_SIZE_M = 32 BLOCK_SIZE_N = 32 BLOCK_SIZE_K = 32 GROUP_SIZE_M = 8 USE_GPU = True + @triton.jit def matmul_kernel( # Pointers to matrices @@ -227,7 +227,7 @@ def matmul_kernel( # Advance the ptrs to the next K block. a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk - + # Convert the accumulator to the output matrix C's type if needed. c = accumulator @@ -236,14 +236,13 @@ def matmul_kernel( offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - + #TODO: Currently masked load is not supported yet. #c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) #tl.store(c_ptrs, c, mask=c_mask) tl.store(c_ptrs, c) - # %% # We can now create a convenience wrapper function that only takes two input tensors, # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. @@ -256,9 +255,10 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): M, K = a.shape K, N = b.shape #TODO: Currently masked load is not supported yet. - assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and (K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" + assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( + K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" if c is None: - # Allocates output. + # Allocates output. 
c = torch.empty((M, N), device=a.device, dtype=a.dtype) else: assert c.shape == (M, N), "Incompatible dimensions" @@ -270,9 +270,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # c.stride(0), c.stride(1), # - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # ) return c @@ -298,7 +296,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU and TorchCPU match") else: - print("❌ TritonCPU and TorchCPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}') + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') # %% # Benchmark @@ -326,13 +325,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonGPU and TorchGPU match") else: - print("❌ TritonGPU and TorchGPU differ, the maximum difference is "f'{torch.max(torch.abs(triton_output - torch_output))}') + print("❌ TritonGPU and TorchGPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') LINE_VALS += ['triton-gpu', 'torch-gpu'] LINE_NAMES += ['TritonGPU', 'TorchGPU'] LINE_STYLES += [('yellow', '-'), ('red', '-')] - # %% # Seems like we're good to go! @@ -359,7 +358,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', args={}, # Values for function arguments not in `x_names` and `y_name`. 
)) - def benchmark(M, N, K, provider): import os @@ -383,7 +381,8 @@ def benchmark(M, N, K, provider): ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles) elif provider == 'torch-cpu': c = torch.empty((M, N), device=a.device, dtype=a.dtype) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, + is_cpu=True) elif provider == 'triton-cpu-single': c = torch.empty((M, N), device=a.device, dtype=a.dtype) ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h index b2edc5e98b36..ba2d64d8f5f0 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h +++ b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h @@ -225,10 +225,12 @@ struct ReduceScanOpConversionBase : public OpConversionPattern { createShuffleDummies(Location loc, ValueRange inputs, ConversionPatternRewriter &rewriter) const { if (shuffleDummies.empty()) { + SmallVector dummyShape({1}); for (auto val : inputs) { auto ty = cast(val.getType()); shuffleDummies.push_back(rewriter.create( - loc, rewriter.getZeroAttr(ty.cloneWith(1, ty.getElementType())))); + loc, rewriter.getZeroAttr( + ty.cloneWith(dummyShape, ty.getElementType())))); } } return shuffleDummies; From 45d02ad49a01911354b1494a58fc79f5bec77559 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 20 Jun 2024 12:56:09 -0500 Subject: [PATCH 029/165] Use static compilation for kernels. (#29) Signed-off-by: Ilya Enkovich --- python/triton/runtime/build.py | 2 + third_party/cpu/backend/compiler.py | 16 +- third_party/cpu/backend/driver.cpp | 224 ---------------------------- third_party/cpu/backend/driver.py | 94 +++--------- 4 files changed, 29 insertions(+), 307 deletions(-) delete mode 100644 third_party/cpu/backend/driver.cpp diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 4568686be953..a2072981bae8 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -36,6 +36,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. 
if src.endswith(".cpp") or src.endswith(".cc"): cc_cmd += ["-std=c++17", "-fopenmp"] + if src.endswith(".s"): + cc_cmd += ["-gdwarf-5"] ret = subprocess.check_call(cc_cmd) if ret == 0: return so diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index c3f11334750a..0a98532eceba 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -42,7 +42,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "bc" + self.binary_ext = "asm" def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -138,22 +138,14 @@ def make_llir(src, metadata, options): return ret @staticmethod - def make_bc(src, metadata, options): - if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1": - from triton.runtime.cache import get_cache_manager - - asm = llvm.translate_to_host_asm(src, options.enable_fp_fusion) - fn_cache_manager = get_cache_manager(metadata['hash']) - fn_cache_manager.put(asm, f"{metadata['name']}.asm") - - ret = llvm.translate_to_bc(src) - return ret + def make_asm(src, metadata, options): + return llvm.translate_to_host_asm(src, options.enable_fp_fusion) def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) - stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options) + stages["asm"] = lambda src, metadata: self.make_asm(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.cpp b/third_party/cpu/backend/driver.cpp deleted file mode 100644 index babff3dfdebe..000000000000 --- a/third_party/cpu/backend/driver.cpp +++ /dev/null @@ -1,224 +0,0 @@ -//===- driver.cpp ---------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "llvm/Bitcode/BitcodeReader.h" -#include "llvm/ExecutionEngine/Orc/CompileUtils.h" -#include "llvm/ExecutionEngine/Orc/Core.h" -#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" -#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" -#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" -#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" -#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" -#include "llvm/ExecutionEngine/SectionMemoryManager.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/TargetSelect.h" - -#include -#include -#include -#include -#include -#include -#include - -#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION -#include - -static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { - int device_id; - if (!PyArg_ParseTuple(args, "i", &device_id)) - return NULL; - - return Py_BuildValue("{s:i}", "max_shared_mem", 0); -} - -bool getBoolEnv(const std::string &env) { - const char *s = std::getenv(env.c_str()); - std::string str(s ? 
s : ""); - std::transform(str.begin(), str.end(), str.begin(), - [](unsigned char c) { return std::tolower(c); }); - return (str == "on" || str == "true" || str == "1"); -} - -llvm::orc::ThreadSafeContext &getThreadSafeContext() { - static llvm::orc::ThreadSafeContext tsc; - static std::once_flag init_flag; - std::call_once(init_flag, []() { - auto context = std::make_unique(); - tsc = llvm::orc::ThreadSafeContext(std::move(context)); - }); - return tsc; -} - -std::string llvmErrToString(const llvm::Error &err) { - std::string res; - llvm::raw_string_ostream os(res); - os << err; - return res; -}; - -struct CompiledKernel { - std::unique_ptr execution_session; - std::unique_ptr data_layout; - std::unique_ptr mangle; - std::unique_ptr object_layer; - std::unique_ptr compiler_layer; - llvm::orc::JITDylib *dylib = nullptr; - - CompiledKernel() = default; - CompiledKernel(CompiledKernel &&) = default; - - ~CompiledKernel() { - if (execution_session) - llvm::cantFail(execution_session->endSession()); - } -}; - -std::vector> compiled_kernels; - -static PyObject *loadBitcode(PyObject *self, PyObject *args) { - const char *name; - int shared; - PyObject *py_bytes; - int devId; - - if (!PyArg_ParseTuple(args, "sSii", &name, &py_bytes, &shared, &devId)) { - std::cerr << "loadBitcode arg parse failed" << std::endl; - return NULL; - } - - std::string kernel_name = name; - size_t binary_size = PyBytes_Size(py_bytes); - const char *binary_ptr = PyBytes_AsString(py_bytes); - - llvm::LLVMContext context; - auto buf = llvm::MemoryBuffer::getMemBuffer( - llvm::StringRef(binary_ptr, binary_size)); - auto mod = llvm::parseBitcodeFile(*buf, context); - if (!mod) { - std::cerr << "Failed to parse LLVM bitcode module" << std::endl; - return NULL; - } - - if (getBoolEnv("MLIR_ENABLE_DUMP")) { - llvm::errs() << "********** Loaded Module (kernel_name=" << name - << ") **********\n" - << **mod << "\n"; - } - - auto init_err = llvm::InitializeNativeTarget(); - if (init_err) { - std::cerr << "Failed to initialize native target." << std::endl; - return NULL; - } - - llvm::InitializeNativeTargetAsmPrinter(); - llvm::InitializeNativeTargetAsmParser(); - - auto self_epc = - llvm::cantFail(llvm::orc::SelfExecutorProcessControl::Create()); - - auto detect_host_res = llvm::orc::JITTargetMachineBuilder::detectHost(); - if (!detect_host_res) { - std::cerr << "Failed to initialize JITTargetMachineBuilder: " - << llvmErrToString(detect_host_res.takeError()); - return NULL; - } - llvm::orc::JITTargetMachineBuilder tmb = std::move(*detect_host_res); - - auto data_layout_res = tmb.getDefaultDataLayoutForTarget(); - if (!data_layout_res) { - std::cerr << "Failed to initialize data layout: " - << llvmErrToString(data_layout_res.takeError()); - return NULL; - } - - CompiledKernel kernel; - kernel.execution_session = - std::make_unique(std::move(self_epc)); - kernel.data_layout = - std::make_unique(std::move(*data_layout_res)); - kernel.mangle = std::make_unique( - *kernel.execution_session, *kernel.data_layout); - kernel.object_layer = std::make_unique( - *kernel.execution_session, - []() { return std::make_unique(); }); - kernel.compiler_layer = std::make_unique( - *kernel.execution_session, *kernel.object_layer, - std::make_unique(std::move(tmb))); - - auto dylib_res = kernel.execution_session->createJITDylib("
"); - if (!dylib_res) { - std::cerr << "Failed to create initialize JITDylib: " - << llvmErrToString(dylib_res.takeError()); - return NULL; - } - - kernel.dylib = &(*dylib_res); - kernel.dylib->addGenerator(llvm::cantFail( - llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( - kernel.data_layout->getGlobalPrefix()))); - - // Compile module. - (**mod).setDataLayout(*kernel.data_layout); - llvm::orc::ThreadSafeModule tsm(std::move(*mod), getThreadSafeContext()); - auto err = kernel.compiler_layer->add(*kernel.dylib, std::move(tsm)); - if (err) { - std::cerr << "Cannot add LLVM module: " << llvmErrToString(err); - return NULL; - } - - // Find kernel function pointer. - auto lookup_res = - kernel.execution_session->lookup({kernel.dylib}, (*kernel.mangle)(name)); - if (!lookup_res) { - std::cerr << "Failed to find function " << std::string(name) - << "\nError: " << llvmErrToString(lookup_res.takeError()); - return NULL; - } - uint64_t fn_ptr = lookup_res->getAddress().getValue(); - - compiled_kernels.push_back( - std::make_unique(std::move(kernel))); - auto *kernel_ptr = compiled_kernels.back().get(); - - return Py_BuildValue("(KKii)", reinterpret_cast(kernel_ptr), - reinterpret_cast(fn_ptr), 0, 0); -} - -static PyObject *initContext(PyObject *self, PyObject *args) { - return Py_BuildValue("(K)", (uint64_t)0); -} - -static PyObject *initDevices(PyObject *self, PyObject *args) { - return Py_BuildValue("(i)", 1); -} - -static PyMethodDef ModuleMethods[] = { - {"load_binary", loadBitcode, METH_VARARGS, - "Load provided SPV into ZE driver"}, - {"get_device_properties", getDeviceProperties, METH_VARARGS, - "Get the properties for a given device"}, - {NULL, NULL, 0, NULL} // sentinel -}; - -static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cpu_utils", - NULL, // documentation - -1, // size - ModuleMethods}; - -PyMODINIT_FUNC PyInit_cpu_utils(void) { - PyObject *m = PyModule_Create(&ModuleDef); - if (m == NULL) { - return NULL; - } - PyModule_AddFunctions(m, ModuleMethods); - return m; -} diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 1018f64d5b35..126e41d3bd7a 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -8,74 +8,9 @@ from triton.backends.compiler import GPUTarget dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") -llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm") -llvm_root = os.path.expanduser(llvm_root) -llvm_dirs = os.listdir(llvm_root) -if len(llvm_dirs) == 1: - llvm_root = os.path.join(llvm_root, llvm_dirs[0]) -include_dir = [ - os.path.join(dirname, "include"), - os.path.join(llvm_root, "include"), -] -library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")] -libraries = [ - "LLVMOrcJIT", - "LLVMPasses", - "LLVMX86CodeGen", - "LLVMX86AsmParser", - "LLVMX86Desc", - "LLVMX86Info", - "LLVMGlobalISel", - "LLVMSelectionDAG", - "LLVMHipStdPar", - "LLVMCoroutines", - "LLVMipo", - "LLVMFrontendOpenMP", - "LLVMInstrumentation", - "LLVMAsmPrinter", - "LLVMCodeGen", - "LLVMObjCARCOpts", - "LLVMLinker", - "LLVMVectorize", - "LLVMScalarOpts", - "LLVMInstCombine", - "LLVMFrontendOffloading", - "LLVMExecutionEngine", - "LLVMAggressiveInstCombine", - "LLVMTransformUtils", - "LLVMTarget", - "LLVMRuntimeDyld", - "LLVMJITLink", - "LLVMIRPrinter", - "LLVMBitWriter", - "LLVMAnalysis", - "LLVMProfileData", - "LLVMSymbolize", - "LLVMDebugInfoDWARF", - "LLVMObject", - "LLVMTextAPI", - "LLVMMCParser", - "LLVMMCDisassembler", - "LLVMMC", - "LLVMIRReader", - 
"LLVMCFGuard", - "LLVMBitReader", - "LLVMAsmParser", - "LLVMCore", - "LLVMBinaryFormat", - "LLVMOrcTargetProcess", - "LLVMTargetParser", - "LLVMRemarks", - "LLVMOrcShared", - "LLVMOption", - "LLVMDebugInfoCodeView", - "LLVMCodeGenTypes", - "LLVMBitstreamReader", - "LLVMSupport", - "LLVMDemangle", - "stdc++", - "z", -] +include_dir = [os.path.join(dirname, "include")] +library_dir = [os.path.join(dirname, "lib")] +libraries = ["stdc++"] def compile_module_from_src(src, name): @@ -110,9 +45,26 @@ def __new__(cls): return cls.instance def __init__(self): - dirname = os.path.dirname(os.path.realpath(__file__)) - mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils") - self.load_binary = mod.load_binary + pass + + def load_binary(self, name, src, shared_mem, device): + # src actually holds asm text, compile to a shared library. + key = hashlib.md5(src).hexdigest() + cache = get_cache_manager(key) + cache_path = cache.get_file(f"{name}.so") + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + asm_path = os.path.join(tmpdir, "kernel.s") + Path(asm_path).write_bytes(src) + Path("kernel.s").write_bytes(src) + so = _build(name, asm_path, tmpdir, library_dir, include_dir, ["gcc", "m"]) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), f"{name}.so", binary=True) + import ctypes + lib = ctypes.cdll.LoadLibrary(cache_path) + fn_ptr = getattr(lib, name) + fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value + return (fn_ptr, fn_ptr_as_void_p, 0, 0) def get_device_properties(self, *args): return {"max_shared_mem": 0} From f768debbdb86e25e3e2ed4f294dd92008d883d7f Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 20 Jun 2024 14:09:54 -0500 Subject: [PATCH 030/165] Move byte manipulation ops from elwise ops conversion. 
(#28) Signed-off-by: Ilya Enkovich --- .../cpu/include/TritonToTritonCPU/Passes.h | 1 + .../cpu/include/TritonToTritonCPU/Passes.td | 14 ++ .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 1 + .../TritonToTritonCPU/ConvertElemManipOps.cpp | 208 ++++++++++++++++++ .../ConvertElementwiseOps.cpp | 129 ----------- .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 1 + 6 files changed, 225 insertions(+), 129 deletions(-) create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index b5107e5e78a3..14df893f0bac 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -19,6 +19,7 @@ namespace cpu { #include "cpu/include/TritonToTritonCPU/Passes.h.inc" std::unique_ptr> createConvertElementwiseOps(); +std::unique_ptr> createConvertElemManipOps(); std::unique_ptr> createConvertMemoryOps(); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 5dd3bf903440..dfac926a9f5b 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -31,6 +31,20 @@ def ConvertElementwiseOps : Pass<"triton-cpu-convert-elementwise-ops", "mlir::Mo "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertElemManipOps : Pass<"triton-cpu-convert-elem-manip-ops", "mlir::ModuleOp"> { + let summary = "Convert elements manipulation ops (transpose, shuffle, etc.)."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertElemManipOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + def ConvertPtrOps : Pass<"triton-cpu-convert-ptr-ops", "mlir::ModuleOp"> { let summary = "Convert Triton ops related to pointer arithmetics."; let description = [{ diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index 636ea039e718..dc34c5bd0199 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonToTritonCPU ConvertControlFlowOps.cpp ConvertDotOp.cpp ConvertElementwiseOps.cpp + ConvertElemManipOps.cpp ConvertHistogramOp.cpp ConvertMemoryOps.cpp ConvertPtrOps.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp new file mode 100644 index 000000000000..99211ea90e41 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -0,0 +1,208 @@ +#include "OpTypeConversion.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Membar.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" 
+ +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTELEMMANIPOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ElemManipOpConversionTarget : public ConversionTarget { +public: + explicit ElemManipOpConversionTarget(MLIRContext &ctx, + TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + addIllegalOp(); + } +}; + +struct ReshapeOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + assert(isa(op.getType())); + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcShape = dyn_cast(src.getType()).getShape(); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + auto dstShape = resTy.getShape(); + auto elemTy = resTy.getElementType(); + + // There are restrictions on how shape can be modified by ShapeCastOp + // when rank is changed. For now, we simply detect it and handle through + // a cast to 1D vector. Better solution may be required later. + if (canCastShape(srcShape, dstShape)) { + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), src); + } else { + SmallVector tmpShape({resTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, elemTy), src); + rewriter.replaceOpWithNewOp( + op, VectorType::get(dstShape, elemTy), tmp); + } + return success(); + } + +private: + bool canCastShape(ArrayRef src, ArrayRef dst) const { + if (src.size() == dst.size()) + return true; + if (src.size() > dst.size()) + return canCastShape(dst, src); + + size_t srcIdx = 0; + size_t dstIdx = 0; + while (srcIdx < src.size() && dstIdx < dst.size()) { + if (src[srcIdx] == 1) { + ++srcIdx; + } else { + // Source dim size should be a product of continuous dest dim sizes. + int64_t srcSize = src[srcIdx++]; + int64_t dstSize = dst[dstIdx++]; + while (dstSize < srcSize && dstIdx < dst.size()) + dstSize *= dst[dstIdx++]; + if (dstSize != srcSize) + return false; + } + } + + // Skip trailing 1s. 
+ while (srcIdx < src.size() && src[srcIdx] == 1) + ++srcIdx; + while (dstIdx < dst.size() && dst[dstIdx] == 1) + ++dstIdx; + + return srcIdx == src.size() && dstIdx == dst.size(); + } +}; + +struct TransOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto val = rewriter.getRemappedValue(op.getSrc()); + auto order = op.getOrder(); + SmallVector permutation(order.begin(), order.end()); + rewriter.replaceOpWithNewOp(op, val, permutation); + return success(); + } +}; + +struct JoinOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto interleave = rewriter.create(loc, lhs, rhs); + // JoinOp creates a new dimension, but InterleaveOp doubles the final one. + // Use ShapeCastOp to get the required shape. + auto resTy = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resTy, interleave); + return success(); + } +}; + +struct CatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + SmallVector indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]); + std::iota(indices.begin(), indices.end(), 0); + rewriter.replaceOpWithNewOp(op, lhs, rhs, indices); + return success(); + } +}; + +struct ConvertElemManipOps + : public triton::impl::ConvertElemManipOpsBase { + using ConvertElemManipOpsBase::ConvertElemManipOpsBase; + + ConvertElemManipOps() : ConvertElemManipOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ElemManipOpConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add>( + typeConverter, context); + patterns.add>( + typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertElemManipOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index cadec818910b..7edf15f2e921 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -51,16 +51,10 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addDynamicallyLegalOp( [](triton::BitcastOp op) { return isa(op.getType()); }); - 
addIllegalOp(); - addIllegalOp(); addIllegalOp(); addIllegalOp(); - addIllegalOp(); addIllegalOp(); addIllegalOp(); - addIllegalOp(); - addIllegalOp(); - addIllegalOp(); } }; @@ -84,70 +78,6 @@ struct ConstantOpConversion : public OpConversionPattern { } }; -struct ReshapeOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - assert(isa(op.getType())); - auto loc = op.getLoc(); - auto src = rewriter.getRemappedValue(op.getSrc()); - auto srcShape = dyn_cast(src.getType()).getShape(); - auto resTy = - dyn_cast(getTypeConverter()->convertType(op.getType())); - auto dstShape = resTy.getShape(); - auto elemTy = resTy.getElementType(); - - // There are restrictions on how shape can be modified by ShapeCastOp - // when rank is changed. For now, we simply detect it and handle through - // a cast to 1D vector. Better solution may be required later. - if (canCastShape(srcShape, dstShape)) { - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), src); - } else { - SmallVector tmpShape({resTy.getNumElements()}); - auto tmp = rewriter.create( - loc, VectorType::get(tmpShape, elemTy), src); - rewriter.replaceOpWithNewOp( - op, VectorType::get(dstShape, elemTy), tmp); - } - return success(); - } - -private: - bool canCastShape(ArrayRef src, ArrayRef dst) const { - if (src.size() == dst.size()) - return true; - if (src.size() > dst.size()) - return canCastShape(dst, src); - - size_t srcIdx = 0; - size_t dstIdx = 0; - while (srcIdx < src.size() && dstIdx < dst.size()) { - if (src[srcIdx] == 1) { - ++srcIdx; - } else { - // Source dim size should be a product of continuous dest dim sizes. - int64_t srcSize = src[srcIdx++]; - int64_t dstSize = dst[dstIdx++]; - while (dstSize < srcSize && dstIdx < dst.size()) - dstSize *= dst[dstIdx++]; - if (dstSize != srcSize) - return false; - } - } - - // Skip trailing 1s. - while (srcIdx < src.size() && src[srcIdx] == 1) - ++srcIdx; - while (dstIdx < dst.size() && dst[dstIdx] == 1) - ++dstIdx; - - return srcIdx == src.size() && dstIdx == dst.size(); - } -}; - struct MulhiUIOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -204,57 +134,6 @@ struct ClampFOpConversion : public OpConversionPattern { } }; -struct TransOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::TransOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto val = rewriter.getRemappedValue(op.getSrc()); - auto order = op.getOrder(); - SmallVector permutation(order.begin(), order.end()); - rewriter.replaceOpWithNewOp(op, val, permutation); - return success(); - } -}; - -struct JoinOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto lhs = rewriter.getRemappedValue(op.getLhs()); - auto rhs = rewriter.getRemappedValue(op.getRhs()); - auto interleave = rewriter.create(loc, lhs, rhs); - // JoinOp creates a new dimension, but InterleaveOp doubles the final one. - // Use ShapeCastOp to get the required shape. 
- auto resTy = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resTy, interleave); - return success(); - } -}; - -struct CatOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto lhs = rewriter.getRemappedValue(op.getLhs()); - auto rhs = rewriter.getRemappedValue(op.getRhs()); - auto lhsTy = dyn_cast(lhs.getType()); - auto rhsTy = dyn_cast(rhs.getType()); - SmallVector indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]); - std::iota(indices.begin(), indices.end(), 0); - rewriter.replaceOpWithNewOp(op, lhs, rhs, indices); - return success(); - } -}; - struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -326,20 +205,12 @@ struct ConvertElementwiseOps patterns.add>( typeConverter, context); - patterns.add>( - typeConverter, context); - patterns.add>( - typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); - patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index ec7c62f72f52..c7e7de72eecf 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -11,6 +11,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); pm.addPass(mlir::triton::cpu::createConvertPtrOps()); pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); + pm.addPass(mlir::triton::cpu::createConvertElemManipOps()); pm.addPass(mlir::triton::cpu::createConvertDotOp()); pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); pm.addPass(mlir::triton::cpu::createConvertReductionOp()); From 33e4a0b6920c0f8e38abf9fb0cb639563014408a Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Thu, 20 Jun 2024 16:50:35 -0700 Subject: [PATCH 031/165] [TUTORIAL] Add the non-persistent softmax and make it for CPU (#22) * [TUTORIAL] Add 02-fused-softmax with the previous non-persistent implementation * Add torch.compile cases * Preallocate output buffer for softmax tutorial --- python/tutorials/01-vector-add.py | 48 +++- python/tutorials/02-fused-softmax-cpu.py | 244 ++++++++++++++++++ .../tutorials/03-matrix-multiplication-cpu.py | 30 ++- 3 files changed, 296 insertions(+), 26 deletions(-) create mode 100644 python/tutorials/02-fused-softmax-cpu.py diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 2a660be1fd8d..a6bfb8371f85 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -26,7 +26,7 @@ DEVICE = triton.runtime.driver.active.get_active_torch_device() GPU_BLOCK_SIZE = 1024 CPU_BLOCK_SIZE = 4096 -USE_GPU = True +USE_GPU = False @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. 
@@ -84,14 +84,34 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: torch.manual_seed(0) size = 98432 +triton.runtime.driver.set_active_to_cpu() x = torch.rand(size, device=DEVICE) y = torch.rand(size, device=DEVICE) -output_torch = x + y -output_triton = add(x, y) -print(output_torch) -print(output_triton) -print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') +output_torch_cpu = torch.add(x, y) +output_triton_cpu = add(x, y, None, is_cpu=True) +print(output_torch_cpu) +print(output_triton_cpu) +print(f'The maximum difference between torch-cpu and triton-cpu is ' + f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') + +LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] +LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU'] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + x = x.to(DEVICE) + y = y.to(DEVICE) + output_torch = x + y + output_triton = add(x, y) + print(output_torch) + print(output_triton) + print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('yellow', '-'), ('red', '-')] # %% # Seems like we're good to go! @@ -125,27 +145,31 @@ def benchmark(size, provider): y = torch.rand(size, device=DEVICE, dtype=torch.float32) if DEVICE == 'cpu': + is_cpu = True triton.runtime.driver.set_active_to_cpu() else: + is_cpu = False triton.runtime.driver.set_active_to_gpu() quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=is_cpu) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles, is_cpu=is_cpu) elif provider == 'torch-cpu': # Note that we preallocate the output buffer here to only measure the kernel performance # without a large chunk of memory allocation. 
output = torch.empty_like(x) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles, - is_cpu=True) + is_cpu=is_cpu) elif provider == 'triton-cpu-single': output = torch.empty_like(x) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, + is_cpu=is_cpu) elif provider == 'triton-cpu': output = torch.empty_like(x) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, + is_cpu=is_cpu) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/02-fused-softmax-cpu.py b/python/tutorials/02-fused-softmax-cpu.py new file mode 100644 index 000000000000..1fce4c78345f --- /dev/null +++ b/python/tutorials/02-fused-softmax-cpu.py @@ -0,0 +1,244 @@ +""" +Fused Softmax +============= + +In this tutorial, you will write a fused softmax operation that is significantly faster +than PyTorch's native op for a particular class of matrices: those whose rows can fit in +the GPU's SRAM. + +In doing so, you will learn about: + +* The benefits of kernel fusion for bandwidth-bound operations. + +* Reduction operators in Triton. + +""" + +# %% +# Motivations +# ----------- +# +# Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. +# Let us consider instead the case of a simple (numerically stabilized) softmax operation: + +import torch + +import triton +import triton.language as tl + +USE_GPU = False + + +@torch.jit.script +def naive_softmax(x): + """Compute row-wise softmax of X using native pytorch + + We subtract the maximum element in order to avoid overflows. Softmax is invariant to + this shift. + """ + # read MN elements ; write M elements + x_max = x.max(dim=1)[0] + # read MN + M elements ; write MN elements + z = x - x_max[:, None] + # read MN elements ; write MN elements + numerator = torch.exp(z) + # read MN elements ; write M elements + denominator = numerator.sum(dim=1) + # read MN + M elements ; write MN elements + ret = numerator / denominator[:, None] + # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements + return ret + + +# %% +# When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}` +# requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements. +# This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads +# X once and does all the necessary computations on-chip. +# Doing so would require reading and writing back only :math:`MN` bytes, so we could +# expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`). +# The `torch.jit.script` flags aims to perform this kind of "kernel fusion" automatically +# but, as we will see later, it is still far from ideal. + +# %% +# Compute Kernel +# -------------- +# +# Our softmax kernel works as follows: each program loads a row of the input matrix X, +# normalizes it and writes back the result to the output Y. 
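+#
+# Concretely, for every row :math:`i` the kernel computes the numerically stabilized softmax
+#
+# .. math::
+#     y_{ij} = \frac{e^{x_{ij} - \max_k x_{ik}}}{\sum_{k} e^{x_{ik} - \max_k x_{ik}}}
+#
+# which matches :code:`naive_softmax` above, just without the intermediate round trips to memory.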
+# +# Note that one important limitation of Triton is that each block must have a +# power-of-two number of elements, so we need to internally "pad" each row and guard the +# memory operations properly if we want to handle any possible input shapes: + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): + # The rows of the softmax are independent, so we parallelize across those + row_idx = tl.program_id(0) + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + +# %% +# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. + + +def softmax(x, y=None): + n_rows, n_cols = x.shape + # The block size is the smallest power of two greater than the number of columns in `x` + BLOCK_SIZE = triton.next_power_of_2(n_cols) + # Another trick we can use is to ask the compiler to use more threads per row by + # increasing the number of warps (`num_warps`) over which each row is distributed. + # You will see in the next tutorial how to auto-tune this value in a more natural + # way so you don't have to come up with manual heuristics yourself. + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + # Allocate output + if y is None: + y = torch.empty_like(x) + # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of + # the input matrix + softmax_kernel[(n_rows, )]( + y, + x, + x.stride(0), + y.stride(0), + n_cols, + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return y + + +# %% +# Unit Test +# --------- + +# %% +# We make sure that we test our kernel on a matrix with an irregular number of rows and columns. +# This will allow us to verify that our padding mechanism works. 
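+#
+# For the 781-column input used below, :code:`BLOCK_SIZE = triton.next_power_of_2(781) = 1024`, so each
+# program pads its row with 243 masked-out elements loaded as :code:`-inf`; since :math:`e^{-\infty} = 0`,
+# they change neither the row maximum nor the sum of exponentials, and the result stays exact.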
+ +triton.runtime.driver.set_active_to_cpu() + +torch.manual_seed(0) +x = torch.randn(1823, 781, device='cpu') +y_triton_cpu = softmax(x) +y_torch_cpu = torch.softmax(x, axis=1) +assert torch.allclose(y_triton_cpu, y_torch_cpu), (y_triton_cpu, y_torch_cpu) + +LINE_VALS = [ + 'triton-cpu-single', + 'triton-cpu', + 'torch-cpu-compile', + 'torch-cpu-jit', + 'torch-cpu-native', +] +LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (compile)', 'TorchCPU (jit)', 'TorchCPU (native)'] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-'), ('green', '--'), ('green', '-.')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + x = x.to('cuda') + y_triton_gpu = softmax(x) + y_torch_gpu = torch.softmax(x, axis=1) + assert torch.allclose(y_triton_gpu, y_torch_gpu), (y_triton_gpu, y_torch_gpu) + LINE_VALS += ['triton-gpu', 'torch-gpu-native', 'torch-gpu-jit'] + LINE_NAMES += ['TritonGPU', 'TorchGPU (native)', 'TorchGPU (jit)'] + LINE_STYLES += [('yellow', '-'), ('red', '-'), ('red', '--')] + +# %% +# As expected, the results are identical. + +# %% +# Benchmark +# --------- +# +# Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows. +# We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above. + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=['N'], # argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 52, 2)], # different possible values for `x_name` + line_arg='provider', # argument name whose value corresponds to a different line in the plot + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel="GB/s", # label name for the y-axis + plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. + args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` + )) +def benchmark(M, N, provider): + import os + + # Currently compilation time is very long. Let's show the progress. 
+ print(f"Running {provider} with {M} x {N}...") + + device = 'cpu' if 'cpu' in provider else 'cuda' + x = torch.randn(M, N, device=device, dtype=torch.float32) + + if device == 'cpu': + is_cpu = True + y = torch.empty_like(x) + triton.runtime.driver.set_active_to_cpu() + if 'single' in provider: + os.environ['TRITON_CPU_SINGLE_CORE'] = '1' + else: + os.unsetenv('TRITON_CPU_SINGLE_CORE') + else: + is_cpu = False + y = None + triton.runtime.driver.set_active_to_gpu() + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-cpu-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, + is_cpu=is_cpu) + if provider == 'torch-cpu-jit': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu) + if provider == 'torch-cpu-compile': + compiled = torch.compile(naive_softmax) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, is_cpu=is_cpu) + if provider == 'triton-cpu-single': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu) + if provider == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu) + if provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu) + if provider == 'torch-gpu-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, + is_cpu=is_cpu) + if provider == 'torch-gpu-jit': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu) + gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +benchmark.run(show_plots=True, print_data=True) + +# %% +# In the above plot, we can see that: +# - Triton is 4x faster than the Torch JIT. This confirms our suspicions that the Torch JIT does not do any fusion here. +# - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to read, understand and maintain**. +# Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape. diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index e96d04614661..937cbc652ba7 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -158,7 +158,7 @@ BLOCK_SIZE_N = 32 BLOCK_SIZE_K = 32 GROUP_SIZE_M = 8 -USE_GPU = True +USE_GPU = False @triton.jit @@ -217,9 +217,9 @@ def matmul_kernel( # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. - #TODO: Currently masked load is not supported yet. - #a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - #b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # TODO: Currently masked load is not supported yet. + # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) a = tl.load(a_ptrs) b = tl.load(b_ptrs) # We accumulate along the K dimension. @@ -237,9 +237,9 @@ def matmul_kernel( offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - #TODO: Currently masked load is not supported yet. 
- #c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - #tl.store(c_ptrs, c, mask=c_mask) + # TODO: Currently masked load is not supported yet. + # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + # tl.store(c_ptrs, c, mask=c_mask) tl.store(c_ptrs, c) @@ -309,9 +309,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): # We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices, # but feel free to arrange this script as you wish to benchmark any other matrix shape. -LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] -LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU'] -LINE_STYLES = [('blue', '-'), ('green', '-'), ('cyan', '-')] +LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-compile'] +LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (compile)'] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-')] if USE_GPU and triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() @@ -366,12 +366,14 @@ def benchmark(M, N, K, provider): b = torch.randn((K, N), device=device, dtype=torch.float32) if device == 'cpu': + c = torch.empty((M, N), device=a.device, dtype=a.dtype) triton.runtime.driver.set_active_to_cpu() if 'single' in provider: os.environ['TRITON_CPU_SINGLE_CORE'] = '1' else: os.unsetenv('TRITON_CPU_SINGLE_CORE') else: + c = None triton.runtime.driver.set_active_to_gpu() quantiles = [0.5, 0.2, 0.8] @@ -379,15 +381,15 @@ def benchmark(M, N, K, provider): ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) elif provider == 'triton-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles) - elif provider == 'torch-cpu': - c = torch.empty((M, N), device=a.device, dtype=a.dtype) + elif provider == 'torch-cpu-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, is_cpu=True) + elif provider == 'torch-cpu-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu-single': - c = torch.empty((M, N), device=a.device, dtype=a.dtype) ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True) elif provider == 'triton-cpu': - c = torch.empty((M, N), device=a.device, dtype=a.dtype) ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) From 054f1f3b96c50d514a120314e5ab2a23a767c660 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 20 Jun 2024 18:50:59 -0500 Subject: [PATCH 032/165] Enable few more core tests for CPU. (#31) * Enable test_enable_fp_fusion for CPU. Signed-off-by: Ilya Enkovich * Enable test_optimize_thread_locality for CPU. 
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 1229a212c0ad..61cfdd8dadcc 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2781,6 +2781,7 @@ def scan_kernel(out_ptr, in_ptr, M: tl.constexpr, N: tl.constexpr): torch.testing.assert_close(ref.to(torch.int32), output, atol=0, rtol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op", ['sum', 'max', 'min']) @pytest.mark.parametrize("BLOCK_N", [32, 64, 128]) @@ -2822,8 +2823,8 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): BLOCK_M = 32 x = torch.randn((BLOCK_M, N), dtype=torch.float32, device=device) y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device=device) - h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N) - if not is_interpreter(): + h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N, NUM_PID_N=num_pid_n) + if not is_interpreter() and not is_cpu(): assert h.asm['ttgir'].count( '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) @@ -6483,6 +6484,7 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s # ----------------------- +@pytest.mark.cpu @pytest.mark.parametrize("enable_fp_fusion", [False, True]) @pytest.mark.parametrize("default_override", [False, True]) def test_enable_fp_fusion(enable_fp_fusion, default_override, device): @@ -6504,10 +6506,12 @@ def mul_add(data): else: h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) - if not is_cuda(): - return - found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None - assert found_fma == enable_fp_fusion + if is_cuda(): + found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None + assert found_fma == enable_fp_fusion + elif is_cpu(): + found_fma = re.search(r'vfma', h.asm["asm"].decode('utf-8')) is not None + assert found_fma == enable_fp_fusion # ----------------------- From 78851cbf0392e2d6303f4913c2755c96f6bfd93a Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 20 Jun 2024 18:51:53 -0500 Subject: [PATCH 033/165] Support tt.split for CPU. 
(#30) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 2 + .../TritonToTritonCPU/ConvertElemManipOps.cpp | 41 +++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 61cfdd8dadcc..7974a0b8526b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2145,6 +2145,7 @@ def kernel(X, Y, Z): np.testing.assert_equal([10, 20], to_numpy(z)) +@pytest.mark.cpu @pytest.mark.interpreter def test_split(device): @@ -2167,6 +2168,7 @@ def kernel(X, Z1, Z2, N: tl.constexpr): np.testing.assert_equal(to_numpy(z2_ref), to_numpy(z2)) +@pytest.mark.cpu @pytest.mark.interpreter def test_split_to_scalar(device): diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp index 99211ea90e41..a39a93e42446 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -48,6 +48,7 @@ class ElemManipOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -166,6 +167,45 @@ struct CatOpConversion : public OpConversionPattern { } }; +struct SplitOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = cast(src.getType()); + auto resTy = getTypeConverter()->convertType(op.getType(0)); + + SmallVector results; + if (srcTy.getRank() == 1) { + results.push_back(rewriter.create(loc, src, 0)); + results.push_back(rewriter.create(loc, src, 1)); + } else { + SmallVector tmpShape({srcTy.getNumElements()}); + auto tmp = rewriter.create( + loc, VectorType::get(tmpShape, srcTy.getElementType()), src); + + SmallVector evenIndices; + SmallVector oddIndices; + for (int64_t i = 0; i < srcTy.getNumElements(); i += 2) { + evenIndices.push_back(i); + oddIndices.push_back(i + 1); + } + + Value res1 = + rewriter.create(loc, tmp, tmp, evenIndices); + Value res2 = + rewriter.create(loc, tmp, tmp, oddIndices); + results.push_back(rewriter.create(loc, resTy, res1)); + results.push_back(rewriter.create(loc, resTy, res2)); + } + rewriter.replaceOp(op, results); + return success(); + } +}; + struct ConvertElemManipOps : public triton::impl::ConvertElemManipOpsBase { using ConvertElemManipOpsBase::ConvertElemManipOpsBase; @@ -187,6 +227,7 @@ struct ConvertElemManipOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From f13afde7defcbd635b91061dfe8e8ff84dfeafd1 Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Mon, 24 Jun 2024 22:01:32 -0700 Subject: [PATCH 034/165] [BACKEND][CPU] Make the CPU backend buildable and runnable in Mac M1. (#18) Add header for unique_ptr in CPU launcher. 
--- python/triton/runtime/build.py | 18 ++++++++++++++++++ third_party/cpu/backend/driver.py | 1 + 2 files changed, 19 insertions(+) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index a2072981bae8..8c97f1f0b50d 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -1,11 +1,26 @@ +import contextlib +import sys +import platform +import io import sysconfig import os import shutil import subprocess +import setuptools +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + def _build(name, src, srcdir, library_dirs, include_dirs, libraries): suffix = sysconfig.get_config_var('EXT_SUFFIX') + system = platform.system() so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) # try to avoid setuptools if possible cc = os.environ.get("CC") @@ -30,6 +45,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + # Use dynamic lookup to load Python library on Mac + if system == "Darwin": + cc_cmd += ["-undefined", "dynamic_lookup"] cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 126e41d3bd7a..ebce4229d7af 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -143,6 +143,7 @@ def format_of(ty): #include #include #include +#include #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION #include From 6122eafef0506d96ca94fe041dfb94365f8e90f0 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 25 Jun 2024 01:05:17 -0500 Subject: [PATCH 035/165] [CPU] Add conversion for unsupported BF16 ops via target-specific stage (#27) * Remove unused code. Signed-off-by: Ilya Enkovich * Fma is always allowed on CPU. Signed-off-by: Ilya Enkovich * Add unsupported op conversions for BF16 type. 
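The conversions added in this patch follow a cast-around pattern: a BF16 operation the target cannot run natively is rewritten to extend its operands to FP32, perform the FP32 op, and truncate the result back to BF16. A rough PyTorch-level illustration of that strategy follows; the pass itself rewrites MLIR vector ops, and the helper name below is made up for this sketch:

import torch

def bf16_add_via_fp32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Mirrors the cast-around idea: extend to fp32, run the op, truncate back to bf16.
    return (a.to(torch.float32) + b.to(torch.float32)).to(torch.bfloat16)

a = torch.randn(8, dtype=torch.bfloat16)
b = torch.randn(8, dtype=torch.bfloat16)
print(bf16_add_via_fp32(a, b))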
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- bin/RegisterTritonDialects.h | 2 + .../triton/Dialect/TritonCPU/CMakeLists.txt | 1 - .../TritonCPU/Transforms/CMakeLists.txt | 3 - .../Dialect/TritonCPU/Transforms/Passes.h | 16 -- .../Dialect/TritonCPU/Transforms/Passes.td | 6 - .../Transforms/TritonCPUConversion.h | 31 --- lib/Conversion/CMakeLists.txt | 3 - lib/Conversion/TritonCPUToLLVM/CMakeLists.txt | 20 -- .../TritonCPUToLLVM/CPUTargetInfo.cpp | 49 ----- .../TritonCPUToLLVM/ControlFlowOpToLLVM.cpp | 37 ---- .../TritonCPUToLLVM/FuncOpToLLVM.cpp | 54 ----- .../TritonCPUToLLVM/PrintOpToLLVM.cpp | 131 ----------- .../TritonCPUToLLVM/SPMDOpToLLVM.cpp | 39 ---- .../TritonCPUToLLVM/TritonCPUToLLVM.cpp | 117 ---------- .../TritonCPUToLLVM/TypeConverter.cpp | 31 --- .../TritonToTritonCPU/CMakeLists.txt | 15 -- .../TritonToTritonCPU/TritonCPUConversion.cpp | 108 ---------- .../TritonToTritonCPUPass.cpp | 41 ---- lib/Dialect/TritonCPU/CMakeLists.txt | 1 - .../TritonCPU/Transforms/CMakeLists.txt | 13 -- python/setup.py | 2 +- python/src/llvm.cc | 15 ++ python/test/unit/language/test_core.py | 3 - third_party/cpu/CMakeLists.txt | 2 +- third_party/cpu/backend/compiler.py | 17 ++ third_party/cpu/include/CMakeLists.txt | 1 + .../TritonCPUTransforms/CMakeLists.txt | 3 + .../cpu/include/TritonCPUTransforms/Passes.h | 32 +++ .../cpu/include/TritonCPUTransforms/Passes.td | 39 ++++ third_party/cpu/lib/CMakeLists.txt | 1 + .../lib/TritonCPUTransforms/CMakeLists.txt | 7 + .../ConvertUnsupportedOps.cpp | 204 ++++++++++++++++++ .../DecomposeFpConversions.cpp | 81 +++++++ .../cpu/lib/TritonCPUTransforms/OptCommon.h | 46 ++++ third_party/cpu/triton_cpu.cc | 7 + 35 files changed, 457 insertions(+), 721 deletions(-) delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/Passes.h delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/Passes.td delete mode 100644 include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h delete mode 100644 lib/Conversion/TritonCPUToLLVM/CMakeLists.txt delete mode 100644 lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp delete mode 100644 lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp delete mode 100644 lib/Conversion/TritonToTritonCPU/CMakeLists.txt delete mode 100644 lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp delete mode 100644 lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp delete mode 100644 lib/Dialect/TritonCPU/Transforms/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt create mode 100644 third_party/cpu/include/TritonCPUTransforms/Passes.h create mode 100644 third_party/cpu/include/TritonCPUTransforms/Passes.td create mode 100644 third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp create mode 100644 third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp create mode 100644 third_party/cpu/lib/TritonCPUTransforms/OptCommon.h diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index ca922e824793..c67987e8159e 100644 --- 
a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -18,6 +18,7 @@ #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" #include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" #include "cpu/include/TritonToTritonCPU/Passes.h" #include "nvidia/include/NVGPUToLLVM/Passes.h" #include "nvidia/include/TritonNVIDIAGPUToLLVM/Passes.h" @@ -74,6 +75,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // CPU passes mlir::triton::cpu::registerTritonToTritonCPUPasses(); mlir::triton::cpu::registerTritonToTritonCPUPipeline(); + mlir::triton::cpu::registerTritonCPUTransformsPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); diff --git a/include/triton/Dialect/TritonCPU/CMakeLists.txt b/include/triton/Dialect/TritonCPU/CMakeLists.txt index 9f57627c321f..f33061b2d87c 100644 --- a/include/triton/Dialect/TritonCPU/CMakeLists.txt +++ b/include/triton/Dialect/TritonCPU/CMakeLists.txt @@ -1,2 +1 @@ add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt b/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt deleted file mode 100644 index 6aa946f64932..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name TritonCPU) -add_public_tablegen_target(TritonCPUTransformsIncGen) diff --git a/include/triton/Dialect/TritonCPU/Transforms/Passes.h b/include/triton/Dialect/TritonCPU/Transforms/Passes.h deleted file mode 100644 index f31e47317080..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/Passes.h +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef TRITON_DIALECT_TRITONCPU_TRANSFORMS_PASSES_H_ -#define TRITON_DIALECT_TRITONCPU_TRANSFORMS_PASSES_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace triton { -namespace cpu {} // namespace cpu -} // namespace triton - -/// Generate the code for registering passes. -#define GEN_PASS_REGISTRATION -#include "triton/Dialect/TritonCPU/Transforms/Passes.h.inc" - -} // namespace mlir -#endif diff --git a/include/triton/Dialect/TritonCPU/Transforms/Passes.td b/include/triton/Dialect/TritonCPU/Transforms/Passes.td deleted file mode 100644 index a1d5271ee6e7..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/Passes.td +++ /dev/null @@ -1,6 +0,0 @@ -#ifndef TRITONCPU_PASSES -#define TRITONCPU_PASSES - -include "mlir/Pass/PassBase.td" - -#endif diff --git a/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h b/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h deleted file mode 100644 index 01c24e19c60e..000000000000 --- a/include/triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h +++ /dev/null @@ -1,31 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// Defines utilities to use while converting to the TritonCPU dialect. 
-// -//===----------------------------------------------------------------------===// - -#ifndef TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ -#define TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ - -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { - -class TritonCPUTypeConverter : public TypeConverter { -public: - TritonCPUTypeConverter(MLIRContext *context); - -private: - MLIRContext *context; -}; - -class TritonCPUConversionTarget : public ConversionTarget { - -public: - explicit TritonCPUConversionTarget(MLIRContext &ctx, - TritonCPUTypeConverter &typeConverter); -}; - -} // namespace mlir - -#endif // TRITON_DIALECT_TRITONCPU_TRANSFORMS_TRITONCPUCONVERSION_H_ diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 426b22a42ef6..143a4375a811 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,5 +1,2 @@ -# TODO(minjang): I will remove these scratches soon. -# add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) -# add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) diff --git a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt deleted file mode 100644 index db507557fb22..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -add_triton_library(TritonCPUToLLVM - ControlFlowOpToLLVM.cpp - CPUTargetInfo.cpp - FuncOpToLLVM.cpp - PrintOpToLLVM.cpp - SPMDOpToLLVM.cpp - TypeConverter.cpp - TritonCPUToLLVM.cpp - - DEPENDS - TritonCPUConversionPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - TritonAnalysis - TritonIR - TritonCPUIR - TritonCPUTransforms -) diff --git a/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp b/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp deleted file mode 100644 index 8dd050b80bbf..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/CPUTargetInfo.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" - -namespace { -LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName("printf"); - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); - - auto *context = rewriter.getContext(); - - // int printf(char* format, ...) - SmallVector argsType{ptr_ty(context)}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - return rewriter.create(UnknownLoc::get(context), funcName, - funcType); -} -} // namespace - -namespace mlir::triton::cpu { - -Value CPUTargetInfo::programId(ConversionPatternRewriter &rewriter, - Location loc, LLVM::LLVMFuncOp funcOp, - int axis) const { - assert(axis >= 0 && axis < 3); - - // program_id for CPU is provided as function arguments. The last three - // arguments are __grid0 to __grid2 of i32. 
- assert(funcOp && funcOp.getArguments().size() >= 3); - return funcOp.getArgument(funcOp.getArguments().size() - 3 + axis); -} - -void CPUTargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int /*formatStrByteCount*/, - ValueRange args) const { - auto loc = UnknownLoc::get(rewriter.getContext()); - SmallVector formatStrAndArgs{formatStrStart}; - for (auto arg : args) { - formatStrAndArgs.push_back(arg); - } - call(getPrintfDeclaration(rewriter), formatStrAndArgs); -} -} // namespace mlir::triton::cpu diff --git a/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp deleted file mode 100644 index a270c0d60845..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/ControlFlowOpToLLVM.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" -#include "llvm/Support/ErrorHandling.h" - -namespace { - -using namespace mlir; -using namespace mlir::triton; - -struct ReturnOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto funcOp = op->getParentOfType(); - if (funcOp->hasAttr("cpu.kernel")) { - if (op.getNumOperands() > 0) { - return rewriter.notifyMatchFailure( - op, "Kernel functions do not support return with operands"); - } - rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), - op->getAttrs()); - } else { - llvm_unreachable("Not implemented"); - } - return success(); - } -}; - -} // namespace - -void mlir::triton::cpu::populateControlFlowOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp deleted file mode 100644 index 9ecd470345ad..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include "mlir/Support/LogicalResult.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" - -namespace mlir { -FailureOr -convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp, - ConversionPatternRewriter &rewriter, - const LLVMTypeConverter &converter); -} - -namespace { - -using namespace mlir; -using namespace mlir::triton; - -struct FuncOpConversion : public ConvertOpToLLVMPattern { - FuncOpConversion(LLVMTypeConverter &converter, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit) {} - - LogicalResult - matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (!LLVM::isKernel(funcOp)) { - llvm_unreachable("Not implemented"); - } - - LLVM::LLVMFuncOp newFuncOp = - *mlir::convertFuncOpToLLVMFuncOp(funcOp, rewriter, *getTypeConverter()); - if (!newFuncOp) { - return failure(); - } - - auto ctx = funcOp->getContext(); - if (LLVM::isKernel(funcOp)) { - // Set an attribute to indicate this function is a kernel entry. 
- newFuncOp->setAttr("cpu.kernel", - rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); - } else { - llvm_unreachable("Not implemented"); - } - - rewriter.eraseOp(funcOp); - return success(); - } -}; - -} // namespace - -void mlir::triton::cpu::populateFuncOpConversionPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp deleted file mode 100644 index b424cf8e37b7..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/PrintOpToLLVM.cpp +++ /dev/null @@ -1,131 +0,0 @@ -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/PatternMatch.h" -#include "triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" - -namespace { - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -struct PrintOpConversion : public ConvertOpToLLVMPattern { - explicit PrintOpConversion(LLVMTypeConverter &typeConverter, - const CPUTargetInfo &targetInfo, - PatternBenefit benefit) - : mlir::ConvertOpToLLVMPattern(typeConverter, benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - - auto getPid = [&](int axis) { - return targetInfo.programId( - rewriter, loc, op->getParentOfType(), axis); - }; - SmallVector values = {getPid(0), getPid(1), getPid(2)}; - - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - os << "pid (" << getFormatSubstr(values[0]) << ", " - << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) - << ")" << op.getPrefix(); - - for (size_t i = 0; i < op.getNumOperands(); i++) { - auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); - if (dyn_cast(op.getOperand(i).getType())) { - llvm_unreachable("Not implemented for tensor types"); - } - - // Only support scalars for now. - assert(elems.size() == 1); - if (i != 0) { - os << ", "; - } - os << getFormatSubstr(elems[0]); - values.push_back(elems[0]); - } - - llPrintf(formatStr, values, rewriter); - rewriter.eraseOp(op); - return success(); - } - - // TODO: This code is the same as the GPU-backend code. Consider refactoring. - std::string getFormatSubstr(Value value, bool hex = false, - std::optional width = std::nullopt) const { - Type type = value.getType(); - if (isa(type)) { - return "%p"; - } - // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the - // type (so 4 for fp16, 8 for int32, 16 for int64). - if (hex) { - // Ignore `width` for `hex` values, pad to typeWidth. 
- std::string ret = - "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); - if (type.getIntOrFloatBitWidth() > 32) { - ret += "ll"; - } - ret += "x"; - return ret; - } - - std::string prefix = "%"; - if (width.has_value()) { - prefix += std::to_string(*width); - } else if (hex) { - prefix += "0"; - prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); - } - - if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { - return prefix + "f"; - } else if (type.isSignedInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "lli"; - else - return prefix + "i"; - } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "llu"; - else - return prefix + "u"; - } - assert(false && "not supported type"); - return ""; - } - - Value llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter, - int *formatStrByteCount = nullptr) const { - assert(!msg.empty() && "printf with empty string not supported"); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value msgValue = - LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), - rewriter, "printfFormat_", msgNewline); - targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args); - if (formatStrByteCount) - *formatStrByteCount = msgNewline.size_in_bytes(); - return msgValue; - } - -protected: - const CPUTargetInfo &targetInfo; -}; - -} // namespace - -void mlir::triton::cpu::populatePrintOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - const CPUTargetInfo &targetInfo, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp deleted file mode 100644 index 65fef7a7d0d5..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/SPMDOpToLLVM.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/Utility.h" - -namespace { - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -struct GetProgramIdOpConversion - : public ConvertOpToLLVMPattern { - explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter, - const CPUTargetInfo &targetInfo, - PatternBenefit benefit = 1) - : ConvertOpToLLVMPattern(typeConverter, benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value programId = targetInfo.programId( - rewriter, op->getLoc(), op->getParentOfType(), - op.getAxisAsInt()); - rewriter.replaceOp(op, programId); - return success(); - } - -private: - const CPUTargetInfo &targetInfo; -}; - -} // namespace - -void mlir::triton::cpu::populateSPMDOpToLLVMPattern( - LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - const CPUTargetInfo &targetInfo, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); -} diff --git a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp b/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp deleted file mode 100644 index cb15f87ee206..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/TritonCPUToLLVM.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include "mlir/Analysis/DataFlowFramework.h" -#include 
"mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/LLVMCommon/VectorPattern.h" -#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonCPUToLLVM/Passes.h" -#include "triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h" -#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -namespace mlir { -namespace triton { -#define GEN_PASS_DEF_CONVERTTRITONCPUTOLLVM -#include "triton/Conversion/TritonCPUToLLVM/Passes.h.inc" -} // namespace triton -} // namespace mlir - -using namespace mlir; - -namespace { - -class TritonLLVMFunctionConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addLegalDialect(); - addLegalOp(); - } -}; - -class TritonLLVMConversionTarget : public ConversionTarget { -public: - explicit TritonLLVMConversionTarget(MLIRContext &ctx) - : ConversionTarget(ctx) { - addLegalDialect(); - addLegalDialect(); - addIllegalDialect(); - addIllegalDialect(); - addLegalOp(); - } -}; - -struct ConvertTritonCPUToLLVM - : public triton::impl::ConvertTritonCPUToLLVMBase { - using ConvertTritonCPUToLLVMBase< - ConvertTritonCPUToLLVM>::ConvertTritonCPUToLLVMBase; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - ConvertTritonCPUToLLVM() : ConvertTritonCPUToLLVMBase() {} - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ModuleOp mod = getOperation(); - mlir::LowerToLLVMOptions option(context); - option.overrideIndexBitwidth(32); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMConversionTarget convTarget(*context); - - // Lower functions - { - mlir::LowerToLLVMOptions option(context); - TritonCPUToLLVMTypeConverter typeConverter(context, option); - TritonLLVMFunctionConversionTarget funcTarget(*context); - RewritePatternSet funcPatterns(context); - mlir::triton::cpu::populateFuncOpConversionPattern( - typeConverter, funcPatterns, - mlir::triton::cpu::patternBenefitDefault); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, - funcPatterns); - if (failed( - applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) - return signalPassFailure(); - } - - RewritePatternSet patterns(context); - mlir::triton::cpu::CPUTargetInfo targetInfo; - int benefit = - mlir::triton::cpu::patternBenefitPrioritizeOverLLVMConversions; - mlir::triton::cpu::populateControlFlowOpToLLVMPattern(typeConverter, - patterns, benefit); - mlir::triton::cpu::populatePrintOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - mlir::triton::cpu::populateSPMDOpToLLVMPattern(typeConverter, patterns, - targetInfo, benefit); - - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) - return signalPassFailure(); - } -}; - -} // anonymous namespace - -namespace 
mlir { -namespace triton { - -std::unique_ptr> createConvertTritonCPUToLLVMPass() { - return std::make_unique(); -} - -} // namespace triton -} // namespace mlir diff --git a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp deleted file mode 100644 index 72ef796fdabb..000000000000 --- a/lib/Conversion/TritonCPUToLLVM/TypeConverter.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "triton/Conversion/TritonCPUToLLVM/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Conversion/MLIRTypes.h" -#include "llvm/Support/ErrorHandling.h" - -using namespace mlir; -using namespace mlir::triton; - -TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( - MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis) - : LLVMTypeConverter(ctx, option, analysis) { - addConversion([&](triton::PointerType type) -> std::optional { - return convertTritonPointerType(type); - }); - - // Internally store bfloat16 as int16 - addConversion([&](BFloat16Type type) -> std::optional { - return IntegerType::get(type.getContext(), 16); - }); -} - -Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( - triton::PointerType type) { - auto ctx = type.getContext(); - auto pointeeType = type.getPointeeType(); - if (isa(pointeeType)) { - llvm_unreachable("Not implemented"); - } - return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()); -} diff --git a/lib/Conversion/TritonToTritonCPU/CMakeLists.txt b/lib/Conversion/TritonToTritonCPU/CMakeLists.txt deleted file mode 100644 index f1b612b9c291..000000000000 --- a/lib/Conversion/TritonToTritonCPU/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_triton_library(TritonToTritonCPU - TritonCPUConversion.cpp - TritonToTritonCPUPass.cpp - - DEPENDS - TritonConversionToCPUPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MLIRTransforms - TritonIR - TritonCPUIR - TritonCPUTransforms -) diff --git a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp b/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp deleted file mode 100644 index 97948404bdbf..000000000000 --- a/lib/Conversion/TritonToTritonCPU/TritonCPUConversion.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include "triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h" - -#include "mlir/IR/IRMapping.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" -#include -#include - -using namespace mlir; -using namespace mlir::triton::cpu; - -// -// TypeConverter -// -TritonCPUTypeConverter::TritonCPUTypeConverter(MLIRContext *context) - : context(context) { - addConversion([](Type type) { return type; }); - - // Add encoding for tensor - addConversion([this](RankedTensorType tensorType) -> RankedTensorType { - // TODO: - return tensorType; - }); - - // Add encoding for tensor pointer - addConversion([this](triton::PointerType ptrType) -> triton::PointerType { - // Check whether tensor pointer `tt.ptr>` - auto pointeeTensorType = - dyn_cast(ptrType.getPointeeType()); - if (pointeeTensorType == nullptr) - return ptrType; - - // Add layout into the tensor - auto convertedTensorType = convertType(pointeeTensorType); - return triton::PointerType::get(convertedTensorType, - ptrType.getAddressSpace()); - }); - - // - // Materializations - // - // This will be called when (newArgType != origArgType) - // This will create newArg, and map(origArg, newArg) - addArgumentMaterialization([&](OpBuilder &builder, - RankedTensorType tensorType, ValueRange inputs, - Location loc) -> 
std::optional { - llvm_unreachable("Argument rematerialization should not happen in Triton " - "-> TritonCPU conversion"); - return std::nullopt; - }); - - // If the origValue still has live user(s), use this to - // convert origValue to newValue - addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, - Location loc) -> std::optional { - llvm_unreachable("Source rematerialization should not happen in Triton -> " - "TritonCPU Conversion"); - return std::nullopt; - }); - - // This will be called when (desiredType != newOperandType) - // where, desiredType = typeConverter->convertType(origType) - // NOTE: only for remapped values. - addTargetMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, - ValueRange inputs, Location loc) { - llvm_unreachable("Source rematerialization should not happen in Triton -> " - "TritonCPU Conversion"); - return std::nullopt; - }); -} - -// -// TritonCPUConversion -// -TritonCPUConversionTarget::TritonCPUConversionTarget( - MLIRContext &context, TritonCPUTypeConverter &typeConverter) - : ConversionTarget(context) { - // TODO: we should also verify ops of TritonCPUDialect - addLegalDialect(); - - // Some ops from SCF are illegal - addIllegalOp(); - - addDynamicallyLegalDialect([&](Operation *op) { - bool hasLegalRegions = true; - for (auto ®ion : op->getRegions()) { - hasLegalRegions = hasLegalRegions && typeConverter.isLegal(®ion); - } - if (hasLegalRegions && typeConverter.isLegal(op)) { - return true; - } - return false; - }); - - // We have requirements for the data layouts - addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { - Attribute aEncoding = - cast(dotOp.getA().getType()).getEncoding(); - Attribute bEncoding = - cast(dotOp.getB().getType()).getEncoding(); - // TODO: - return false; - }); -} diff --git a/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp b/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp deleted file mode 100644 index 44c41636a3f3..000000000000 --- a/lib/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include "triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h" - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" -#include "triton/Dialect/TritonCPU/Transforms/TritonCPUConversion.h" -#include "llvm/ADT/APSInt.h" -#include - -#define GEN_PASS_CLASSES -#include "triton/Conversion/TritonToTritonCPU/Passes.h.inc" - -namespace { - -using namespace mlir; -using namespace mlir::triton; -using namespace mlir::triton::cpu; - -class ConvertTritonToTritonCPU - : public ConvertTritonToTritonCPUBase { -public: - ConvertTritonToTritonCPU() = default; - - void runOnOperation() override { - // TODO: - } -}; - -} // namespace - -std::unique_ptr> -mlir::triton::createConvertTritonToTritonCPUPass() { - return std::make_unique<::ConvertTritonToTritonCPU>(); -} diff --git a/lib/Dialect/TritonCPU/CMakeLists.txt b/lib/Dialect/TritonCPU/CMakeLists.txt index 9f57627c321f..f33061b2d87c 100644 --- a/lib/Dialect/TritonCPU/CMakeLists.txt +++ b/lib/Dialect/TritonCPU/CMakeLists.txt @@ -1,2 +1 @@ add_subdirectory(IR) 
-add_subdirectory(Transforms) diff --git a/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt deleted file mode 100644 index 1714215b9434..000000000000 --- a/lib/Dialect/TritonCPU/Transforms/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_triton_library(TritonCPUTransforms - - DEPENDS - TritonCPUTransformsIncGen - - LINK_LIBS PUBLIC - MLIRTransforms - MLIRTransformUtils - TritonAnalysis - TritonIR - TritonCPUIR - MLIRTransformUtils -) diff --git a/python/setup.py b/python/setup.py index 886693762e3c..ec32ea11a5b8 100644 --- a/python/setup.py +++ b/python/setup.py @@ -783,7 +783,7 @@ def get_git_version_suffix(): "tests": [ "autopep8", "isort", - "numpy", + "numpy<2.0.0", "pytest", "pytest-forked", "pytest-xdist", diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 1f7f9b03d676..ba45fbb51462 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -551,6 +551,21 @@ void init_triton_llvm(py::module &&m) { } } }); + + m.def("get_cpu_tripple", []() { return llvm::sys::getProcessTriple(); }); + + m.def("get_cpu_name", []() { return llvm::sys::getHostCPUName().str(); }); + + m.def("get_cpu_features", []() { + auto features = llvm::sys::getHostCPUFeatures(); + + std::set res; + for (auto &f : features) { + if (f.second) + res.insert(f.first().str()); + } + return res; + }); } void triton_stacktrace_signal_handler(void *) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 7974a0b8526b..d6b58a6848d2 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6511,9 +6511,6 @@ def mul_add(data): if is_cuda(): found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None assert found_fma == enable_fp_fusion - elif is_cpu(): - found_fma = re.search(r'vfma', h.asm["asm"].decode('utf-8')) is not None - assert found_fma == enable_fp_fusion # ----------------------- diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 1b08addbc9b7..b107e2434e1e 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -3,6 +3,6 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) - add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM) + add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 0a98532eceba..6a4e3d08535c 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -43,6 +43,9 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) self.binary_ext = "asm" + self.cpu_arch = llvm.get_cpu_tripple().split("-")[0] + self.cpu_name = llvm.get_cpu_name() + self.cpu_features = llvm.get_cpu_features() def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -86,6 +89,19 @@ def make_ttcir(mod, metadata, opt): metadata["cluster_dims"] = (opt.cluster_dims[0], opt.cluster_dims[1], opt.cluster_dims[2]) return mod + def make_tttcir(self, mod, metadata, opt): + # TTCIR -> Target TTCIR + pm = ir.pass_manager(mod.context) + pm.enable_debug() + if self.cpu_arch == "x86_64" and 
"avx512bf16" not in self.cpu_features: + cpu.passes.ttcpuir.add_convert_unsupported_ops(pm) + cpu.passes.ttcpuir.add_decompose_fp_conversions(pm) + passes.common.add_cse(pm) + passes.common.add_symbol_dce(pm) + passes.common.add_canonicalizer(pm) + pm.run(mod) + return mod + @staticmethod def make_llir(src, metadata, options): # warp-specialization mutates num_warps @@ -144,6 +160,7 @@ def make_asm(src, metadata, options): def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) + stages["tttcir"] = lambda src, metadata: self.make_tttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) stages["asm"] = lambda src, metadata: self.make_asm(src, metadata, options) diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt index fc9a19e52b0d..b4c91e794072 100644 --- a/third_party/cpu/include/CMakeLists.txt +++ b/third_party/cpu/include/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt new file mode 100644 index 000000000000..cb2cb234172d --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/CMakeLists.txt @@ -0,0 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUTransforms) +add_public_tablegen_target(TritonCPUTransformsPassIncGen) diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h new file mode 100644 index 000000000000..035122aa98fb --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -0,0 +1,32 @@ +#ifndef TritonCPUTransforms_CONVERSION_PASSES_H +#define TritonCPUTransforms_CONVERSION_PASSES_H + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include + +namespace mlir { + +class ModuleOp; +template class OperationPass; + +namespace triton { +namespace cpu { + +#define GEN_PASS_DECL +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" + +std::unique_ptr> createConvertUnsupportedOps(); +std::unique_ptr> createDecomposeFpConversions(); + +#define GEN_PASS_REGISTRATION +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" + +} // namespace cpu +} // namespace triton + +} // namespace mlir + +#endif diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td new file mode 100644 index 000000000000..0eb4910394fd --- /dev/null +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -0,0 +1,39 @@ +#ifndef TRITONCPUOPT_CONVERSION_PASSES +#define TRITONCPUOPT_CONVERSION_PASSES + +include "mlir/Pass/PassBase.td" + +def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "mlir::ModuleOp"> { + let summary = "Convert operations on unsupported types."; + let description = [{ + This pass converts various operations on data types that are not supported + by the target natively. Operations are converted to a supported data type + with casts added for inputs and the result. + }]; + // TODO: add options to specify which operations to convert. 
+ let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def DecomposeFpConversions : Pass<"triton-cpu-decompose-fp-conversions", "mlir::ModuleOp"> { + let summary = "Decompose fp conversion ops."; + let description = [{ + This pass is used for targets lacking native instructions to convert FP + vectors. By default, LLVM would decompose them using scalar FP conversion + intrinsics. This pass transforms such conversions into vector code + instead. + }]; + // TODO: add options to specify which FP conversions to decompose. + let constructor = "mlir::triton::cpu::createDecomposeFpConversions()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +#endif diff --git a/third_party/cpu/lib/CMakeLists.txt b/third_party/cpu/lib/CMakeLists.txt index 1db64c58ec20..fad51ab86ea9 100644 --- a/third_party/cpu/lib/CMakeLists.txt +++ b/third_party/cpu/lib/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(Analysis) add_subdirectory(TritonCPUToLLVM) +add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt new file mode 100644 index 000000000000..5a52aa7e86b6 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -0,0 +1,7 @@ +add_triton_library(TritonCPUTransforms + ConvertUnsupportedOps.cpp + DecomposeFpConversions.cpp + + DEPENDS + TritonCPUTransformsPassIncGen +) diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp new file mode 100644 index 000000000000..dec4970d1ccd --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -0,0 +1,204 @@ +#include "OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template +struct ConvertBf16ToFp32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + // TODO: support mixed-type ops? + if (!isAllBf16(op->getOperandTypes()) || !isAllBf16(op->getResultTypes())) + return failure(); + + Location loc = op.getLoc(); + OperationState newState(loc, OpT::getOperationName()); + // Convert operands to fp32 and generate fp32 op. 
+ for (auto operand : op->getOperands()) { + Value newOperand = rewriter.create( + loc, toFp32(operand.getType()), operand); + newState.operands.push_back(newOperand); + } + newState.types = toFp32(op->getResultTypes()); + newState.attributes = op->getAttrs(); + auto newOp = rewriter.create(newState); + + // Convert op results back to Bf16 + SmallVector results; + for (auto res : llvm::enumerate(newOp->getResults())) + results.push_back(rewriter.create( + loc, op->getResult(res.index()).getType(), res.value())); + rewriter.replaceOp(op, results); + + return success(); + } + + bool isAllBf16(TypeRange types) const { + return std::all_of(types.begin(), types.end(), + [this](auto ty) { return isBf16(ty); }); + } + + SmallVector toFp32(TypeRange types) const { + SmallVector res; + for (auto ty : types) + res.push_back(::toFp32(ty)); + return res; + } +}; + +template +struct ConvertIToBf16ToFp32 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value fp32Val = + rewriter.create(loc, toFp32(op.getType()), op.getOperand()); + Value res = rewriter.create(loc, op.getType(), fp32Val); + rewriter.replaceOp(op, res); + return success(); + } +}; + +Value convertMemRefToI16(Value memRef, PatternRewriter &rewriter) { + // Memory references for masked operations are always built + // with PtrToMemRefOp. + auto def = memRef.getDefiningOp(); + assert(def); + auto insPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(def); + MemRefType memRefTy = cast(memRef.getType()); + Type newMemRefTy = + MemRefType::get(memRefTy.getShape(), rewriter.getI16Type(), + memRefTy.getLayout(), memRefTy.getMemorySpace()); + Value res = rewriter.create(memRef.getLoc(), newMemRefTy, + def.getSrc()); + rewriter.restoreInsertionPoint(insPoint); + return res; +} + +struct ConvertBf16MaskedLoadOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedLoadOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value newBase = convertMemRefToI16(op.getBase(), rewriter); + Value newPassThru = rewriter.create( + loc, toInt16(op.getPassThru().getType()), op.getPassThru()); + Value intVal = rewriter.create( + loc, toInt16(op.getType()), newBase, op.getIndices(), op.getMask(), + newPassThru); + Value res = rewriter.create(loc, op.getType(), intVal); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertBf16MaskedStoreOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedStoreOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getValueToStore().getType())) + return failure(); + + Location loc = op.getLoc(); + Value newBase = convertMemRefToI16(op.getBase(), rewriter); + Value intVal = rewriter.create( + loc, toInt16(op.getValueToStore().getType()), op.getValueToStore()); + rewriter.replaceOpWithNewOp( + op, newBase, op.getIndices(), op.getMask(), intVal); + return success(); + } +}; + +struct ConvertBf16Abs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::AbsFOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType()) || !isBf16(op.getOperand().getType())) + return failure(); + + Location loc = 
op.getLoc(); + Value src = op.getOperand(); + Value intSrc = + rewriter.create(loc, toInt16(op.getType()), src); + TypedAttr maskAttr = rewriter.getI16IntegerAttr(0x7fff); + if (auto vecTy = dyn_cast(intSrc.getType())) + maskAttr = SplatElementsAttr::get(vecTy, maskAttr); + Value mask = rewriter.create(loc, maskAttr); + Value res = rewriter.create(loc, intSrc, mask); + res = rewriter.create(loc, op.getType(), res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertUnsupportedOps + : public triton::impl::ConvertUnsupportedOpsBase { + using ConvertUnsupportedOpsBase::ConvertUnsupportedOpsBase; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add(context); + patterns.add(context); + + patterns.add(context); + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertUnsupportedOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp new file mode 100644 index 000000000000..a82958b3ae2b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -0,0 +1,81 @@ +#include "OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DECOMPOSEFPCONVERSIONS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +struct Fp32ToBf16Conversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter &rewriter) const override { + Value src = op.getIn(); + if (!isBf16(op.getType()) || !isFp32(src.getType())) + return failure(); + + Location loc = op.getLoc(); + Value i32Src = + rewriter.create(loc, toInt32(src.getType()), src); + TypedAttr shiftValAttr = rewriter.getI32IntegerAttr(16); + if (auto vecTy = dyn_cast(i32Src.getType())) + shiftValAttr = SplatElementsAttr::get(vecTy, shiftValAttr); + Value shiftedSrc = rewriter.create( + loc, i32Src, rewriter.create(loc, shiftValAttr)); + Value i16Res = rewriter.create(loc, toInt16(src.getType()), + shiftedSrc); + Value res = rewriter.create(loc, op.getType(), i16Res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct DecomposeFpConversions + : public triton::impl::DecomposeFpConversionsBase { + using DecomposeFpConversionsBase::DecomposeFpConversionsBase; + + DecomposeFpConversions() : DecomposeFpConversionsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + patterns.add(context); + + if 
(failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createDecomposeFpConversions() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h b/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h new file mode 100644 index 000000000000..a9fe054b8ede --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h @@ -0,0 +1,46 @@ +#ifndef TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H +#define TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H + +#include "mlir/IR/BuiltinTypes.h" + +namespace mlir { +namespace triton { +namespace cpu { + +inline bool isTyOrVectorOf(mlir::Type ty, mlir::Type elemTy) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.getElementType() == elemTy; + return ty == elemTy; +} + +inline bool isBf16(mlir::Type ty) { + return isTyOrVectorOf(ty, mlir::BFloat16Type::get(ty.getContext())); +} + +inline bool isFp32(mlir::Type ty) { + return isTyOrVectorOf(ty, mlir::Float32Type::get(ty.getContext())); +} + +inline mlir::Type toTyOrVectorOf(mlir::Type ty, mlir::Type elemTy) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.cloneWith(std::nullopt, elemTy); + return elemTy; +} + +inline mlir::Type toInt16(mlir::Type ty) { + return toTyOrVectorOf(ty, mlir::IntegerType::get(ty.getContext(), 16)); +} + +inline mlir::Type toInt32(mlir::Type ty) { + return toTyOrVectorOf(ty, mlir::IntegerType::get(ty.getContext(), 32)); +} + +inline mlir::Type toFp32(mlir::Type ty) { + return toTyOrVectorOf(ty, mlir::Float32Type::get(ty.getContext())); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index fa4eb818dce5..06bfef0b299a 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,4 +1,5 @@ #include "TritonCPUToLLVM/Passes.h" +#include "TritonCPUTransforms/Passes.h" #include "TritonToTritonCPU/Passes.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" @@ -34,6 +35,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); }); + m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps()); + }); + m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createDecomposeFpConversions()); + }); m.def("add_vector_to_scf", [](mlir::PassManager &pm, bool full_unroll, unsigned target_rank, bool lower_tensors) { mlir::VectorTransferToSCFOptions opts; From bc568a4dcb468b4b734bd66718af6638480106c4 Mon Sep 17 00:00:00 2001 From: Gregory Shimansky Date: Tue, 25 Jun 2024 15:34:03 -0500 Subject: [PATCH 036/165] Enabled simple build&test workflow, disabled old Integration Tests workflow because it fails to run (#33) * Enabled simple build&test workflow, disabled old Integration Tests workflow because it fails to run. 
Signed-off-by: Gregory Shimansky * Fixed integration-tests source file to disable triggers Signed-off-by: Gregory Shimansky --------- Signed-off-by: Gregory Shimansky --- .github/workflows/build-test.yml | 8 ++++++++ .github/workflows/integration-tests.yml | 15 ++++++++------- .github/workflows/integration-tests.yml.in | 16 +++++++++------- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 8c9bcca7cf28..767abd066b75 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -3,6 +3,14 @@ run-name: ${{ inputs.run_name }} on: workflow_dispatch: + pull_request: + branches: + - main + # You can name your branch dev-foo to get CI runs. + - 'dev-**' + push: + branches: + - main jobs: pre-commit: diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index d3cf4a869a7a..c86266885a98 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -9,13 +9,14 @@ name: Integration Tests on: workflow_dispatch: - pull_request: - branches-ignore: ['llvm-**'] - merge_group: - branches: [main, 'dev-**'] - types: [checks_requested] - push: - branches: [main] +# Disabled automatic triggers because tests in this workflow fail to run. +# pull_request: +# branches-ignore: ['llvm-**'] +# merge_group: +# branches: [main, 'dev-**'] +# types: [checks_requested] +# push: +# branches: [main] concurrency: group: ${{ github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} diff --git a/.github/workflows/integration-tests.yml.in b/.github/workflows/integration-tests.yml.in index cd6c81688e81..8f01a26a00d5 100644 --- a/.github/workflows/integration-tests.yml.in +++ b/.github/workflows/integration-tests.yml.in @@ -8,13 +8,15 @@ name: Integration Tests on: workflow_dispatch: - pull_request: - branches-ignore: ['llvm-**'] - merge_group: - branches: [main, 'dev-**'] - types: [checks_requested] - push: - branches: [main] +# Disabled automatic triggers because tests in this workflow fail to run. +# pull_request: +# # You can name your branch dev-foo to get CI runs. 
+# branches-ignore: ['llvm-**'] +# merge_group: +# branches: [main, 'dev-**'] +# types: [checks_requested] +# push: +# branches: [main] concurrency: group: ${{ github.ref }} From 57bce468eb27cd82566ed81f19ed8644ff905ac0 Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Tue, 25 Jun 2024 16:57:17 -0700 Subject: [PATCH 037/165] [BACKEND][CPU] Specify CPU target to native for GNU/Linux Arm (#34) --- python/triton/runtime/build.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 8c97f1f0b50d..44a8baefe65e 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -21,6 +21,7 @@ def quiet(): def _build(name, src, srcdir, library_dirs, include_dirs, libraries): suffix = sysconfig.get_config_var('EXT_SUFFIX') system = platform.system() + machine = platform.machine() so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) # try to avoid setuptools if possible cc = os.environ.get("CC") @@ -56,6 +57,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += ["-std=c++17", "-fopenmp"] if src.endswith(".s"): cc_cmd += ["-gdwarf-5"] + if system == "Linux" and machine in ("aarch64", "arm64"): + # On Arm backend, some CPU (neoverse-v2) needs to be specified through -mcpu + cc_cmd += ["-mcpu=native"] ret = subprocess.check_call(cc_cmd) if ret == 0: return so From b0ef7b9936f743b015b73a495928580de09e0394 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 2 Jul 2024 12:28:34 -0500 Subject: [PATCH 038/165] Add conversions for mixed precision matmuls. (#32) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 13 +++ third_party/cpu/backend/compiler.py | 7 +- .../cpu/include/TritonCPUTransforms/Passes.h | 3 + .../cpu/include/TritonCPUTransforms/Passes.td | 11 +- .../ConvertUnsupportedOps.cpp | 103 ++++++++++++++++-- third_party/cpu/triton_cpu.cc | 9 +- 6 files changed, 129 insertions(+), 17 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index d6b58a6848d2..6721dc54b89e 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3525,6 +3525,7 @@ def get_test_dot_double_rate_cases(): (16, 16, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)] +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize( "M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack, mma_nonk_size", @@ -3543,6 +3544,18 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty if is_interpreter(): if in_dtype == 'bfloat16': pytest.skip("bfloat16 is not supported in the interpreter") + elif is_cpu(): + if input_precision != "ieee": + pytest.skip(f"{input_precision} not supported on CPU") + if in_dtype == 'float8e4nv' or in_dtype == 'float8e5': + pytest.skip("float8e4nv and float8e5 not supported on CPU") + # This test kernel runs in a single thread and can take a long time + # for bigger sizes with the current codegen on CPU. Limit input sizes + # by default to get more reasonable tests execution time. 
+ if os.environ.get('TRITON_CPU_TEST_DOT_FULL_SIZE', '0') != '1': + M = min(M, 64) + N = min(N, 64) + K = min(K, 32) else: if not is_hip() and (M < 16 or N < 16 or K < 16): pytest.skip("small dots are supported only on HIP at the moment") diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6a4e3d08535c..6ed2a6f6e111 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -93,8 +93,11 @@ def make_tttcir(self, mod, metadata, opt): # TTCIR -> Target TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - if self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features: - cpu.passes.ttcpuir.add_convert_unsupported_ops(pm) + promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features + # We don't have any lowering for mixed precision matmuls, so always use casts for now + convert_mixed_precision_matmul = True + cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul) + if promote_bf16_to_fp32: cpu.passes.ttcpuir.add_decompose_fp_conversions(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index 035122aa98fb..213161fecc8a 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -19,6 +19,9 @@ namespace cpu { #include "cpu/include/TritonCPUTransforms/Passes.h.inc" std::unique_ptr> createConvertUnsupportedOps(); +std::unique_ptr> +createConvertUnsupportedOps(bool promoteBf16ToFp32, + bool convertMixedPrecisionMatmul); std::unique_ptr> createDecomposeFpConversions(); #define GEN_PASS_REGISTRATION diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index 0eb4910394fd..2e92bc42c6c5 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -10,7 +10,16 @@ def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "ml by the target natively. Operations are converted to a supported data type with casts added for inputs and the result. }]; - // TODO: add options to specify which operations to convert. 
+ + let options = [ + Option<"promoteBf16ToFp32", "promote-bf16-to-fp32", + "bool", /*default*/"false", + "Convert BF16 operations to FP32.">, + Option<"convertMixedPrecisionMatmul", "convert-mixed-precision-matmul", + "bool", /*default*/"false", + "Convert inputs of a mixed-precision matmul to a destination type.">, + ]; + let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()"; let dependentDialects = ["mlir::arith::ArithDialect", diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index dec4970d1ccd..5d991b376902 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -11,8 +11,10 @@ namespace mlir { namespace triton { +namespace cpu { #define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS #include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu } // namespace triton } // namespace mlir @@ -165,24 +167,96 @@ struct ConvertBf16Abs : public OpRewritePattern { } }; +struct ConvertMixedPrecisionMatmul + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value acc = op.getAcc(); + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); + auto accTy = cast(acc.getType()); + auto resTy = cast(op.getType()); + + if (lhsTy.getElementType() == resTy.getElementType() && + rhsTy.getElementType() == resTy.getElementType() && + accTy.getElementType() == resTy.getElementType()) + return failure(); + + Type commonElemTy = resTy.getElementType(); + if (lhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth()) + commonElemTy = lhsTy; + if (rhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth()) + commonElemTy = rhsTy; + if (accTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth()) + commonElemTy = accTy; + + lhs = castElemTy(loc, lhs, commonElemTy, rewriter); + rhs = castElemTy(loc, rhs, commonElemTy, rewriter); + acc = castElemTy(loc, acc, commonElemTy, rewriter); + + Value newRes = rewriter.create( + loc, lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes()); + newRes = castElemTy(loc, newRes, resTy.getElementType(), rewriter); + + rewriter.replaceOp(op, newRes); + return success(); + } + + Value castElemTy(Location loc, Value val, Type elemTy, + PatternRewriter &rewriter) const { + auto valTy = cast(val.getType()); + if (valTy.getElementType() == elemTy) + return val; + + auto resTy = toTyOrVectorOf(valTy, elemTy); + if (valTy.getElementType().isInteger()) { + if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth()) + return rewriter.create(loc, resTy, val); + else + return rewriter.create(loc, resTy, val); + } else { + if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth()) + return rewriter.create(loc, resTy, val); + else + return rewriter.create(loc, resTy, val); + } + } +}; + struct ConvertUnsupportedOps - : public triton::impl::ConvertUnsupportedOpsBase { - using ConvertUnsupportedOpsBase::ConvertUnsupportedOpsBase; + : public triton::cpu::impl::ConvertUnsupportedOpsBase< + ConvertUnsupportedOps> { + ConvertUnsupportedOps() = default; + + ConvertUnsupportedOps(bool promoteBf16ToFp32, + bool convertMixedPrecisionMatmul) { + this->promoteBf16ToFp32 = promoteBf16ToFp32; + 
this->convertMixedPrecisionMatmul = convertMixedPrecisionMatmul; + } void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); RewritePatternSet patterns(context); - patterns.add>(context); - patterns.add>(context); - patterns.add>(context); - patterns.add>(context); - patterns.add>(context); - patterns.add(context); - patterns.add(context); - - patterns.add(context); + if (promoteBf16ToFp32) { + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + } + if (convertMixedPrecisionMatmul) { + patterns.add(context); + } if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) return signalPassFailure(); @@ -199,6 +273,13 @@ std::unique_ptr> createConvertUnsupportedOps() { return std::make_unique(); } +std::unique_ptr> +createConvertUnsupportedOps(bool promoteBf16ToFp32, + bool convertMixedPrecisionMatmul) { + return std::make_unique(promoteBf16ToFp32, + convertMixedPrecisionMatmul); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 06bfef0b299a..748b72fe549d 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -35,9 +35,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); }); - m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps()); - }); + m.def("add_convert_unsupported_ops", + [](mlir::PassManager &pm, bool promote_bf16_to_fp32, + bool convert_mixed_precision_matmul) { + pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps( + promote_bf16_to_fp32, convert_mixed_precision_matmul)); + }); m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createDecomposeFpConversions()); }); From ad60606d06fad4e27f4978a8c4fcf1550c815ce4 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Wed, 3 Jul 2024 22:58:06 +0200 Subject: [PATCH 039/165] [Op support] Support 'get_num_programs' (#39) --- python/test/unit/language/test_core.py | 1 + third_party/cpu/backend/driver.py | 8 +++---- .../cpu/include/TritonCPUToLLVM/Passes.td | 18 +++++---------- .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 9 ++++++++ .../TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp | 22 ++++++++++++++++++- 5 files changed, 40 insertions(+), 18 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 6721dc54b89e..52862798392d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6820,6 +6820,7 @@ def temp_in_loop(Z, N: tl.constexpr, BLOCK: tl.constexpr): assert (acc == out).all() +@pytest.mark.cpu @pytest.mark.interpreter def test_num_programs(device): # Assuming that the kernel is launched with a grid of (11, 21, 31) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index ebce4229d7af..2e2afd05ddbb 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -128,8 +128,8 @@ def format_of(ty): arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' kernel_fn_args = [i for i in signature.keys() if i not in constants] kernel_fn_args_list = ', '.join(f"arg{i}" 
for i in kernel_fn_args) if len(kernel_fn_args) > 0 else '' - kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + - ", " if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t" + kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", " + if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t" # generate glue code src = f""" @@ -236,7 +236,7 @@ def format_of(ty): for (size_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; - (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ); }} return; }} @@ -254,7 +254,7 @@ def format_of(ty): #pragma omp parallel for schedule(static) num_threads(max_threads.value()) for (size_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; - (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z); + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ); }} }} diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index 0759ddbf7925..d8b010f35660 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -5,9 +5,7 @@ include "mlir/Pass/PassBase.td" def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { let summary = "Convert FuncOp to LLVM for CPU."; - let description = [{ - - }]; + let description = [{}]; let constructor = "mlir::triton::cpu::createFuncOpToLLVMPass()"; let dependentDialects = ["mlir::arith::ArithDialect", @@ -19,9 +17,7 @@ def FuncOpToLLVM : Pass<"triton-cpu-func-op-to-llvm", "mlir::ModuleOp"> { def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { let summary = "Convert Triton memory operations to LLVM for CPU."; - let description = [{ - - }]; + let description = [{}]; let constructor = "mlir::triton::cpu::createMemoryOpToLLVMPass()"; let dependentDialects = ["mlir::arith::ArithDialect", @@ -34,9 +30,7 @@ def MemoryOpToLLVM : Pass<"triton-cpu-memory-op-to-llvm", "mlir::ModuleOp"> { def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::ModuleOp"> { let summary = "Convert Triton GetProgramId to LLVM for CPU."; - let description = [{ - - }]; + let description = [{}]; let constructor = "mlir::triton::cpu::createGetProgramIdOpToLLVMPass()"; let dependentDialects = ["mlir::LLVM::LLVMDialect", @@ -45,8 +39,7 @@ def GetProgramIdOpToLLVM : Pass<"triton-cpu-get-program-id-op-to-llvm", "mlir::M def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton::FuncOp"> { let summary = "Convert multi-dimensional reductions."; - let description = [{ - }]; + let description = [{}]; let constructor = "mlir::triton::cpu::createLowerMultiReductionPass()"; let dependentDialects = ["mlir::vector::VectorDialect", @@ -56,8 +49,7 @@ def LowerMultiReduction : Pass<"triton-cpu-lower-multi-reduction", "mlir::triton def AtomicOpsToLLVM : Pass<"triton-cpu-atomic-ops-to-llvm", "mlir::ModuleOp"> { let summary = "Convert Triton atomic operations to LLVM."; - let description = [{ - }]; + let description = [{}]; let constructor = "mlir::triton::cpu::createAtomicOpsToLLVMPass()"; let dependentDialects = ["mlir::vector::VectorDialect", diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp 
b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp index 4c5257fcff4c..0d6db8e13154 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -79,6 +79,9 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { amendedInputTy.push_back(i32_ty); amendedInputTy.push_back(i32_ty); amendedInputTy.push_back(i32_ty); + amendedInputTy.push_back(ui32_ty); + amendedInputTy.push_back(ui32_ty); + amendedInputTy.push_back(ui32_ty); auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, funcTy.getResults()); // 2. Modify the argument attributes to add new arguments. @@ -90,6 +93,9 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); amendedAttrs.push_back(rewriter.getNamedAttr( funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); // 3. Add a new arguments to the region @@ -99,6 +105,9 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { region.addArgument(i32_ty, loc); region.addArgument(i32_ty, loc); region.addArgument(i32_ty, loc); + region.addArgument(ui32_ty, loc); + region.addArgument(ui32_ty, loc); + region.addArgument(ui32_ty, loc); rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), amendedFuncOp.end()); return amendedFuncOp; diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp index 4c593f1ff7aa..cdf45de6adc3 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp @@ -52,7 +52,26 @@ struct GetProgramIdOpConversion : public OpConversionPattern { auto funcOp = op->getParentOfType(); assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); auto args = funcOp.getArguments(); - // Last three args are x, y, z program ids. + // First three of last six args are x, y, z program ids. + auto argIdx = args.size() - 6 + op.getAxisAsInt(); + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + rewriter.replaceOp(op, args[argIdx]); + return success(); + } +}; + +struct GetNumProgramsOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(GetNumProgramsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto funcOp = op->getParentOfType(); + assert(funcOp && "expected LLVM::FuncOp as a parent of GetNumProgramsOp"); + auto args = funcOp.getArguments(); + // Last three of args are gridX, gridY, gridZ (bounds) of grid. 
auto argIdx = args.size() - 3 + op.getAxisAsInt(); assert(argIdx < args.size() && "out-of-bounds arg index"); assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); @@ -77,6 +96,7 @@ struct GetProgramIdOpToLLVM RewritePatternSet patterns(context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); From 6534a26ff507e4fa7c4c74d4c7d0db95f5b5a3b6 Mon Sep 17 00:00:00 2001 From: Ruiqi Gao Date: Mon, 8 Jul 2024 11:35:45 -0700 Subject: [PATCH 040/165] Add fast-math option: allow fp reduction reassociation --- python/src/llvm.cc | 42 ++++++++++++++++++----------- third_party/cpu/backend/compiler.py | 7 +++-- third_party/cpu/triton_cpu.cc | 14 +++++++--- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index ba45fbb51462..6d3bb79e25eb 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -1,4 +1,4 @@ -#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp +#include "mlir/IR/BuiltinOps.h" // mlir::ModuleOp #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "triton/Tools/Sys/GetEnv.hpp" @@ -47,7 +47,8 @@ using namespace llvm; std::unique_ptr createTargetMachine(llvm::Module *module, std::string proc, - bool enable_fp_fusion, const std::string &features) { + bool enable_fp_fusion, const std::string &features, + bool enable_fast_math = false) { std::string error; auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); @@ -55,10 +56,21 @@ createTargetMachine(llvm::Module *module, std::string proc, bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); if (enable_fp_fusion) opt.AllowFPOpFusion = llvm::FPOpFusion::Fast; - opt.UnsafeFPMath = false; - opt.NoInfsFPMath = false; - opt.NoNaNsFPMath = true; - opt.TrapUnreachable = true; + + if (enable_fast_math) { + opt.UnsafeFPMath = true; + opt.NoInfsFPMath = true; + opt.NoNaNsFPMath = true; + opt.NoTrappingFPMath = true; + opt.NoSignedZerosFPMath = true; + opt.ApproxFuncFPMath = true; + } else { + opt.UnsafeFPMath = false; + opt.NoInfsFPMath = false; + opt.NoNaNsFPMath = true; + opt.TrapUnreachable = true; + } + opt.MCOptions.AsmVerbose = true; opt.MCOptions.PreserveAsmComments = true; std::unique_ptr machine{target->createTargetMachine( @@ -69,12 +81,10 @@ createTargetMachine(llvm::Module *module, std::string proc, return machine; } -std::string translateLLVMIRToASM(llvm::Module &module, - const std::string &triple, - const std::string &proc, - const std::string &features, - const std::vector &flags, - bool enable_fp_fusion, bool isObject) { +std::string translateLLVMIRToASM( + llvm::Module &module, const std::string &triple, const std::string &proc, + const std::string &features, const std::vector &flags, + bool enable_fp_fusion, bool isObject, bool enable_fast_math = false) { using namespace mlir; // options auto options = llvm::cl::getRegisteredOptions(); @@ -136,7 +146,8 @@ std::string translateLLVMIRToASM(llvm::Module &module, // create machine module.setTargetTriple(triple); - auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features); + auto machine = createTargetMachine(&module, proc, enable_fp_fusion, features, + enable_fast_math); // set data layout module.setDataLayout(machine->createDataLayout()); // emit machine code @@ -419,7 +430,8 @@ void init_triton_llvm(py::module &&m) { m.def( "translate_to_host_asm", - 
[](std::string llvmIR, bool enable_fp_fusion) -> py::object { + [](std::string llvmIR, bool enable_fp_fusion, + bool enable_fast_math) -> py::object { std::string res; { // when allow_threads goes out of scope, gil will be released @@ -439,7 +451,7 @@ void init_triton_llvm(py::module &&m) { res = translateLLVMIRToASM(*module, llvm::sys::getDefaultTargetTriple(), llvm::sys::getHostCPUName().str(), "", {}, - enable_fp_fusion, false); + enable_fp_fusion, false, enable_fast_math); } return py::str(res); }, diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6ed2a6f6e111..4c616164c517 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -22,6 +22,7 @@ class CPUOptions: allowed_dot_input_precisions: Tuple[str] = ("ieee", ) allow_fp8e4nv: bool = False enable_fp_fusion: bool = True + enable_fast_math: bool = False # TODO: We may introduce CPU-specific options like # of cores. @@ -49,6 +50,8 @@ def __init__(self, target: tuple) -> None: def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} + if not "enable_fast_math" in args: + args["enable_fast_math"] = os.getenv("TRITON_CPU_FAST_MATH", "0") == "1" return CPUOptions(**args) def pack_metadata(self, metadata): @@ -124,7 +127,7 @@ def make_llir(src, metadata, options): cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) passes.convert.add_math_to_llvmir(pm) cpu.passes.ttcpuir.add_math_to_libm(pm) - cpu.passes.ttcpuir.add_vector_to_llvmir(pm) + cpu.passes.ttcpuir.add_vector_to_llvmir(pm, options.enable_fast_math) cpu.passes.ttcpuir.add_memref_to_llvmir(pm) passes.convert.add_arith_to_llvmir(pm) cpu.passes.ttcpuir.add_func_to_llvmir(pm) @@ -158,7 +161,7 @@ def make_llir(src, metadata, options): @staticmethod def make_asm(src, metadata, options): - return llvm.translate_to_host_asm(src, options.enable_fp_fusion) + return llvm.translate_to_host_asm(src, options.enable_fp_fusion, options.enable_fast_math) def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 748b72fe549d..5977b6f36d17 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -56,9 +56,17 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { pm.addNestedPass( mlir::triton::cpu::createLowerMultiReductionPass()); }); - m.def("add_vector_to_llvmir", [](mlir::PassManager &pm) { - pm.addPass(mlir::createConvertVectorToLLVMPass()); - }); + m.def("add_vector_to_llvmir", + [](mlir::PassManager &pm, bool reassoc_fp_reduction) { + mlir::ConvertVectorToLLVMPassOptions opts; + opts.reassociateFPReductions = reassoc_fp_reduction; + // opts.force32BitVectorIndices = true; + // opts.amx = false; + // opts.armNeon = false; + // opts.armSVE = false; + // opts.x86Vector = false; + pm.addPass(mlir::createConvertVectorToLLVMPass(opts)); + }); m.def("add_lower_affine", [](mlir::PassManager &pm) { pm.addPass(mlir::createLowerAffinePass()); }); From e7bb5dc41eb8322d95d461313be0315c8026f7f2 Mon Sep 17 00:00:00 2001 From: Ruiqi Gao Date: Mon, 8 Jul 2024 11:51:00 -0700 Subject: [PATCH 041/165] Change the lowering option for vector.multi_reduction from InnerParallel to InnerReduction. 
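For readers unfamiliar with the two MLIR lowering modes, here is a rough numpy sketch of the difference for a 2-D reduction over the inner axis. It is only an illustration of the two strategies with made-up shapes, not the actual vector dialect lowering:

import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)

# InnerReduction: the reduced axis is made innermost, so each row collapses
# with one horizontal reduction (think one vector.reduction per row).
per_row = np.array([row.sum() for row in x], dtype=np.float32)

# InnerParallel: the parallel axis stays vectorized, so the reduced axis is
# walked with plain elementwise adds over length-3 vectors.
acc = np.zeros(3, dtype=np.float32)
for k in range(x.shape[1]):
    acc += x[:, k]

assert np.allclose(per_row, acc)

Both orderings produce the same values; this patch switches the CPU lowering to the InnerReduction style.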
--- third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp index 5c18f15d9d1b..74f81cb0f9cc 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp @@ -35,7 +35,9 @@ struct LowerMultiReduction MLIRContext *context = op->getContext(); RewritePatternSet loweringPatterns(context); - vector::VectorMultiReductionLowering options; + // The default lowering option is InnerParallel + vector::VectorMultiReductionLowering options = + vector::VectorMultiReductionLowering::InnerReduction; vector::populateVectorMultiReductionLoweringPatterns(loweringPatterns, options); From d713fad1469ec9cead2902e9b95d793a0e0cda2a Mon Sep 17 00:00:00 2001 From: Ruiqi Gao Date: Tue, 9 Jul 2024 12:30:44 -0700 Subject: [PATCH 042/165] Fix: TrapUnreachable is not controled by fast-math, we set it unconditionally --- python/src/llvm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 6d3bb79e25eb..b0c67a5e6ff0 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -68,9 +68,9 @@ createTargetMachine(llvm::Module *module, std::string proc, opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; - opt.TrapUnreachable = true; } + opt.TrapUnreachable = true; opt.MCOptions.AsmVerbose = true; opt.MCOptions.PreserveAsmComments = true; std::unique_ptr machine{target->createTargetMachine( From 6a323ab6d34861bdaac3a3e223e0ecc270283de0 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Wed, 17 Jul 2024 11:58:13 -0400 Subject: [PATCH 043/165] [so] Compile asm to .so as part of staged lowering (#53) ... instead of doing it at load time. This more closely mirrors how the other backends work. E.g. for the CUDA backend, we compile ptx -> cubin as part of staged lowering too. 
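For illustration, a minimal Python sketch of the idea, assuming a host gcc on PATH and a hypothetical kernel name; the real code goes through triton.runtime.build._build with the CPU driver's include and library dirs, as the diff below shows:

import ctypes
import os
import subprocess
import tempfile


def asm_to_shared_object(asm_text: str, name: str) -> bytes:
    # Compile the generated assembly once, at compile time, and return the
    # shared object bytes so they can be cached like any other stage output.
    with tempfile.TemporaryDirectory() as tmpdir:
        asm_path = os.path.join(tmpdir, "kernel.s")
        so_path = os.path.join(tmpdir, f"{name}.so")
        with open(asm_path, "w") as f:
            f.write(asm_text)
        subprocess.check_call(["gcc", "-shared", "-fPIC", asm_path, "-o", so_path])
        with open(so_path, "rb") as f:
            return f.read()


def load_kernel(so_bytes: bytes, name: str):
    # At load time, only materialize the cached .so and dlopen it; no compiler
    # invocation is needed anymore.
    tmp = tempfile.NamedTemporaryFile(suffix=".so", delete=False)
    tmp.write(so_bytes)
    tmp.close()
    lib = ctypes.cdll.LoadLibrary(tmp.name)
    return getattr(lib, name)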
--- third_party/cpu/backend/compiler.py | 23 ++++++++++++++++++++++- third_party/cpu/backend/driver.py | 27 +++++++++------------------ 2 files changed, 31 insertions(+), 19 deletions(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 4c616164c517..6e0dc594ee22 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -1,12 +1,16 @@ import functools import hashlib import os +import tempfile +from pathlib import Path from dataclasses import dataclass from typing import Any, Tuple from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget +from triton.runtime.build import _build +import triton.backends.cpu.driver as cpu_driver @dataclass(frozen=True) @@ -43,7 +47,7 @@ def supports_target(target: GPUTarget): def __init__(self, target: tuple) -> None: super().__init__(target) - self.binary_ext = "asm" + self.binary_ext = "so" self.cpu_arch = llvm.get_cpu_tripple().split("-")[0] self.cpu_name = llvm.get_cpu_name() self.cpu_features = llvm.get_cpu_features() @@ -163,12 +167,29 @@ def make_llir(src, metadata, options): def make_asm(src, metadata, options): return llvm.translate_to_host_asm(src, options.enable_fp_fusion, options.enable_fast_math) + @staticmethod + def make_so(src, metadata, options): + with tempfile.TemporaryDirectory() as tmpdir: + asm_path = os.path.join(tmpdir, "kernel.s") + Path(asm_path).write_text(src) + so = _build( + "kernel", + asm_path, + tmpdir, + cpu_driver.library_dir, + cpu_driver.include_dir, + ["gcc", "m"], + ) + with open(so, "rb") as f: + return f.read() + def add_stages(self, stages, options): stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options) stages["tttcir"] = lambda src, metadata: self.make_tttcir(src, metadata, options) stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options) stages["asm"] = lambda src, metadata: self.make_asm(src, metadata, options) + stages["so"] = lambda src, metadata: self.make_so(src, metadata, options) @functools.lru_cache() def hash(self): diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 2e2afd05ddbb..a5eef2ed9fb8 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -47,24 +47,15 @@ def __new__(cls): def __init__(self): pass - def load_binary(self, name, src, shared_mem, device): - # src actually holds asm text, compile to a shared library. 
- key = hashlib.md5(src).hexdigest() - cache = get_cache_manager(key) - cache_path = cache.get_file(f"{name}.so") - if cache_path is None: - with tempfile.TemporaryDirectory() as tmpdir: - asm_path = os.path.join(tmpdir, "kernel.s") - Path(asm_path).write_bytes(src) - Path("kernel.s").write_bytes(src) - so = _build(name, asm_path, tmpdir, library_dir, include_dir, ["gcc", "m"]) - with open(so, "rb") as f: - cache_path = cache.put(f.read(), f"{name}.so", binary=True) - import ctypes - lib = ctypes.cdll.LoadLibrary(cache_path) - fn_ptr = getattr(lib, name) - fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value - return (fn_ptr, fn_ptr_as_void_p, 0, 0) + def load_binary(self, name, kernel, shared_mem, device): + with tempfile.NamedTemporaryFile(mode="wb", suffix=".so") as f: + f.write(kernel) + f.flush() + import ctypes + lib = ctypes.cdll.LoadLibrary(f.name) + fn_ptr = getattr(lib, name) + fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value + return (lib, fn_ptr_as_void_p, 0, 0) def get_device_properties(self, *args): return {"max_shared_mem": 0} From 0992b4d27382bfc2d0aec080f262018bebdb39fa Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 17 Jul 2024 11:21:52 -0500 Subject: [PATCH 044/165] Add libdevice for CPU. (#52) Signed-off-by: Ilya Enkovich --- python/src/ir.cc | 48 ++++++++++ python/test/unit/cpu/test_libdevice.py | 48 ++++++++++ python/triton/language/extra/cpu/__init__.py | 3 + python/triton/language/extra/cpu/libdevice.py | 96 +++++++++++++++++++ third_party/cpu/backend/compiler.py | 1 + .../ConvertElementwiseOps.cpp | 12 +++ 6 files changed, 208 insertions(+) create mode 100644 python/test/unit/cpu/test_libdevice.py create mode 100644 python/triton/language/extra/cpu/__init__.py create mode 100644 python/triton/language/extra/cpu/libdevice.py diff --git a/python/src/ir.cc b/python/src/ir.cc index dd70dc13b511..61e1cdee92d6 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1621,10 +1621,50 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_cosh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_sin", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_sinh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_tan", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_tanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_acos", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_acosh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_asin", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_asinh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_atan", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) + .def("create_atanh", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_log", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); @@ -1633,6 +1673,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_log10", + [](TritonOpBuilder &self, Value &val) -> Value { + return 
self.create(val); + }) .def("create_erf", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); @@ -1653,6 +1697,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_cbrt", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_reduce", [](TritonOpBuilder &self, std::vector operands, int axis) -> OpState { return self.create(operands, axis); }) diff --git a/python/test/unit/cpu/test_libdevice.py b/python/test/unit/cpu/test_libdevice.py new file mode 100644 index 000000000000..22cb6286f3b8 --- /dev/null +++ b/python/test/unit/cpu/test_libdevice.py @@ -0,0 +1,48 @@ +import os +import pytest +import torch + +import triton +import triton.language as tl +from triton.language.extra import libdevice + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + +float_dtypes = ['bfloat16', 'float16', 'float32', 'float64'] + + +@pytest.mark.parametrize("dtype_str", float_dtypes) +@pytest.mark.parametrize("math_fn", [ + "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "log", "log2", + "log10", "sin", "sinh", "tan", "tanh" +]) +@pytest.mark.parametrize("size", [1, 4, 16, 64]) +def test_libdevice(dtype_str, math_fn, size, device): + if not is_cpu(): + pytest.skip("This test is CPU-specific") + + @triton.jit + def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = getattr(libdevice, MATH_FN)(x) + tl.store(dst + idxs, y) + + src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + if math_fn == "acosh": + src = src.abs() + 1 + res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) + kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size) + if math_fn == "cbrt": + ref = src.pow(1 / 3) + else: + ref = getattr(src, math_fn)() + torch.testing.assert_close(ref, res) diff --git a/python/triton/language/extra/cpu/__init__.py b/python/triton/language/extra/cpu/__init__.py new file mode 100644 index 000000000000..229b57d87d65 --- /dev/null +++ b/python/triton/language/extra/cpu/__init__.py @@ -0,0 +1,3 @@ +from . 
import libdevice + +__all__ = ["libdevice"] diff --git a/python/triton/language/extra/cpu/libdevice.py b/python/triton/language/extra/cpu/libdevice.py new file mode 100644 index 000000000000..d7e8cdde3cfd --- /dev/null +++ b/python/triton/language/extra/cpu/libdevice.py @@ -0,0 +1,96 @@ +from triton.language import core, tensor + + +@core.extern +def acos(arg0, _builder=None): + return tensor(_builder.create_acos(arg0.handle), arg0.type) + + +@core.extern +def acosh(arg0, _builder=None): + return tensor(_builder.create_acosh(arg0.handle), arg0.type) + + +@core.extern +def asin(arg0, _builder=None): + return tensor(_builder.create_asin(arg0.handle), arg0.type) + + +@core.extern +def asinh(arg0, _builder=None): + return tensor(_builder.create_asinh(arg0.handle), arg0.type) + + +@core.extern +def atan(arg0, _builder=None): + return tensor(_builder.create_atan(arg0.handle), arg0.type) + + +@core.extern +def atanh(arg0, _builder=None): + return tensor(_builder.create_atanh(arg0.handle), arg0.type) + + +@core.extern +def cbrt(arg0, _builder=None): + return tensor(_builder.create_cbrt(arg0.handle), arg0.type) + + +@core.extern +def cos(arg0, _builder=None): + return tensor(_builder.create_cos(arg0.handle), arg0.type) + + +@core.extern +def cosh(arg0, _builder=None): + return tensor(_builder.create_cosh(arg0.handle), arg0.type) + + +@core.extern +def erf(arg0, _builder=None): + return tensor(_builder.create_erf(arg0.handle), arg0.type) + + +@core.extern +def exp(arg0, _builder=None): + return tensor(_builder.create_exp(arg0.handle), arg0.type) + + +@core.extern +def exp2(arg0, _builder=None): + return tensor(_builder.create_exp2(arg0.handle), arg0.type) + + +@core.extern +def log(arg0, _builder=None): + return tensor(_builder.create_log(arg0.handle), arg0.type) + + +@core.extern +def log2(arg0, _builder=None): + return tensor(_builder.create_log2(arg0.handle), arg0.type) + + +@core.extern +def log10(arg0, _builder=None): + return tensor(_builder.create_log10(arg0.handle), arg0.type) + + +@core.extern +def sin(arg0, _builder=None): + return tensor(_builder.create_sin(arg0.handle), arg0.type) + + +@core.extern +def sinh(arg0, _builder=None): + return tensor(_builder.create_sinh(arg0.handle), arg0.type) + + +@core.extern +def tan(arg0, _builder=None): + return tensor(_builder.create_tan(arg0.handle), arg0.type) + + +@core.extern +def tanh(arg0, _builder=None): + return tensor(_builder.create_tanh(arg0.handle), arg0.type) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6e0dc594ee22..ef04c5089d2e 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -17,6 +17,7 @@ class CPUOptions: # GPU-specific options are used in several places. # For now, we just provide dummy values. 
+ backend_name: str = "cpu" num_warps: int = 0 num_stages: int = 0 num_ctas: int = 0 diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 7edf15f2e921..783fab131862 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -197,11 +197,23 @@ struct ConvertElementwiseOps patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>( typeConverter, context); From caf43d0c2259c65d8c3fcfd0124b89e4275ba485 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Wed, 17 Jul 2024 18:31:30 +0200 Subject: [PATCH 045/165] [Op support] Dot3D support (#43) This commit adds implementation for 3 dimensional dot operation. Signed-off-by: Dmitrii Makarenko --- python/test/unit/language/test_core.py | 12 ++++- .../lib/TritonToTritonCPU/ConvertDotOp.cpp | 50 +++++++++++++++---- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 52862798392d..5c1b2cbece0b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4027,6 +4027,7 @@ def make_finite(x, dtype): or "tcgen05.mma.cta_group::1.kind::f16" in ptx) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str", [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str) @@ -4055,12 +4056,21 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + elif is_cpu(): + input_precision = "ieee" + # TODO(dmitriim): investigate the reason why + # can be fixed with lower tolerance: + # E Mismatched elements: 94 / 32768 (0.287%) + # E Max absolute difference: 0.09375 + # E Max relative difference: 4.812 + if out_dtype_str == "float16" and in_dtype_str == "float16": + pytest.skip(f"{out_dtype_str} with M = {M}, N = {N}, K = {K} has low precision. 
Not clear why.") else: input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16): pytest.skip("small dots are supported only on HIP at the moment") - if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32": + if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32" and not is_cpu(): if not is_interpreter() and triton.runtime.driver.active.utils.get_device_properties( triton.runtime.driver.active.get_current_device())["max_shared_mem"] < 131072: pytest.skip( diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp index 51a5f42fa63a..06cfb0d834d0 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -54,17 +54,45 @@ struct DotOpConversion : public OpConversionPattern { Value a = rewriter.getRemappedValue(op.getA()); Value b = rewriter.getRemappedValue(op.getB()); Value c = rewriter.getRemappedValue(op.getC()); - auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); - auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); - auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); - auto iteratorTypes = rewriter.getArrayAttr( - {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::reduction)}); - rewriter.replaceOpWithNewOp( - op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), - iteratorTypes); - return success(); + + auto aType = cast(a.getType()); + auto bType = cast(b.getType()); + auto cType = cast(c.getType()); + assert(aType.getRank() == bType.getRank() && + bType.getRank() == cType.getRank() && + "Mixed ranks, not 2d or 3d matmul, unknown type of op"); + + uint32_t rank = aType.getRank(); + if (rank == 2) { + auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, + vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } else if (rank == 3) { + auto aMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, + vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } + + return failure(); } }; From 848e43e5e33e8f09cf18ab896a70142577fb6062 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 17 Jul 2024 15:56:38 -0500 Subject: [PATCH 046/165] Support FP8 conversions for CPU. 
(#40) * Mark tests enabled on CPU. Signed-off-by: Ilya Enkovich * Allow tf32 input for CPU. Signed-off-by: Ilya Enkovich * Add conversions for fp8 types on CPU. Signed-off-by: Ilya Enkovich * Add test_conversions.py to CI for CPU. Signed-off-by: Ilya Enkovich * Fix formatting. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich Co-authored-by: Minjang Kim --- .github/workflows/build-test.yml | 1 + python/test/unit/language/test_conversions.py | 49 +- python/test/unit/language/test_core.py | 10 +- third_party/cpu/backend/compiler.py | 11 +- .../cpu/include/TritonCPUTransforms/Passes.h | 3 + .../cpu/include/TritonCPUTransforms/Passes.td | 11 +- .../DecomposeFpConversions.cpp | 483 +++++++++++++++++- .../cpu/lib/TritonCPUTransforms/OptCommon.h | 136 ++++- .../ConvertElementwiseOps.cpp | 40 ++ third_party/cpu/triton_cpu.cc | 9 +- 10 files changed, 689 insertions(+), 64 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 767abd066b75..a2cdc22e1920 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -79,3 +79,4 @@ jobs: - name: Run python unit tests run: | python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu + python -m pytest -s -n 32 --device cpu python/test/unit/language/test_conversions.py diff --git a/python/test/unit/language/test_conversions.py b/python/test/unit/language/test_conversions.py index 7cb4a82bbe46..25607c3dbafd 100644 --- a/python/test/unit/language/test_conversions.py +++ b/python/test/unit/language/test_conversions.py @@ -10,6 +10,9 @@ from triton._internal_testing import is_cuda, is_hip, is_hip_mi300 +def is_cpu(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cpu" + def matching_int(dtype): if dtype.primitive_bitwidth == 8: return torch.int8 @@ -272,24 +275,24 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia ]) def test_typeconvert_upcast(src_dtype, dst_dtype, device): - # On HIP, fp8e4nv upcasting is only supported to bf16 and fp16, and it's only supported on MI300. - if is_cuda(): - if ((src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (8, 9)) - or src_dtype in ('float8e4b8', 'float8e5b16')): - # If the dtype should error out in the given device, we assert that and return - with pytest.raises(triton.CompilationError, match="not supported in this architecture"): - launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) - return - elif is_hip(): - if src_dtype == 'float8e4nv' and ( - dst_dtype == 'float32' or ((dst_dtype in ('bfloat16')) and not is_hip_mi300())): - pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture") - if (src_dtype in ('float8e4b15') or - (src_dtype in ('float8e4b8', 'float8e5b16') and not is_hip_mi300())): - # If the dtype should error out in the given device, we assert that and return - with pytest.raises(triton.CompilationError, match="not supported in this architecture"): - launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) - return + # On HIP, fp8e4nv upcasting is only supported to bf16, and it's only supported on MI300. 
+ if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or not is_hip_mi300()): + # pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture") + # If the dtype should error out in the given device, we assert that and return + with pytest.raises(triton.CompilationError, match="not supported in this architecture"): + launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) + return + + if src_dtype in ('float8e4b8', 'float8e4b15') and is_cpu(): + pytest.skip(f"Conversion from {src_dtype} to {dst_dtype} is not supported on CPU") + + if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9)) + or (src_dtype in ('float8e4b15') and is_hip()) + or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_hip_mi300()))): + # If the dtype should error out in the given device, we assert that and return + with pytest.raises(triton.CompilationError, match="not supported in this architecture"): + launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) + return # dtype : (exponent_bits, mantissa_bits, exponent_bias, max_repr) stuff = { @@ -329,16 +332,18 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device): ('float16', 'float8e4b8', 'rtne', 0x5b80), ]) def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): + if is_cpu() and dst_dtype not in ['float8e5', 'float8e4nv', 'float8e5b16']: + pytest.skip(f"Conversion from {src_dtype} to {dst_dtype} is not supported on CPU") if is_cuda(): if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0): pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+") - if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0): - pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") + if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and (is_hip() or (is_cuda() and torch.cuda.get_device_capability(0) < (9, 0))): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+") - if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne': - pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") + if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and (is_cuda() or (is_hip() and not is_hip_mi300())): + pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU MI300") if is_hip(): if dst_dtype == 'float8e5' and rounding == 'rtne': diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5c1b2cbece0b..a07691c329ea 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -534,6 +534,7 @@ def _min_max_integral_mod_value(dtype_x, dtype_y) -> Optional[int]: return max_info.min, max_info.max +@pytest.mark.cpu def test_dtype_codegen(): for dtype in dtypes_with_bfloat16: full_name = f"triton.language.{dtype}" @@ -3545,10 +3546,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty if in_dtype == 'bfloat16': pytest.skip("bfloat16 is not supported in the interpreter") elif is_cpu(): - if input_precision != "ieee": - pytest.skip(f"{input_precision} not supported on CPU") - if in_dtype == 'float8e4nv' or in_dtype == 'float8e5': - pytest.skip("float8e4nv and 
float8e5 not supported on CPU") # This test kernel runs in a single thread and can take a long time # for bigger sizes with the current codegen on CPU. Limit input sizes # by default to get more reasonable tests execution time. @@ -4517,6 +4514,7 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_ torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("cache", ["", ".ca", ".cg", ".cv"]) def test_load_cache_modifier(cache, device): @@ -4562,6 +4560,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): assert 'ld.global.cg' not in ptx +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("N", [16, 10, 11, 1024]) @pytest.mark.parametrize("num_ctas", num_ctas_list) @@ -4589,6 +4588,7 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): torch.testing.assert_close(dst[:N], src[:N], atol=1e-6, rtol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("has_hints", [False, True]) def test_vectorization_hints(has_hints, device): @@ -4642,6 +4642,7 @@ def _kernel(out_ptr, N: tl.constexpr, BLOCK_N: tl.constexpr): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("cache", ["", ".wb", ".cg", ".cs", ".wt"]) def test_store_cache_modifier(cache, device): @@ -4706,6 +4707,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): assert 'st.global.wt' in ptx +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("eviction_policy", ["", "evict_last", "evict_first"]) def test_store_eviction_policy(eviction_policy, device): diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index ef04c5089d2e..ccf0ace48770 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -24,9 +24,11 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False - allowed_dot_input_precisions: Tuple[str] = ("ieee", ) - allow_fp8e4nv: bool = False + allowed_dot_input_precisions: Tuple[str] = ("ieee", "tf32", "tf32x3") + allow_fp8e4nv: bool = True + allow_fp8e4b15: bool = True enable_fp_fusion: bool = True + max_num_imprecise_acc_default: int = 0 enable_fast_math: bool = False # TODO: We may introduce CPU-specific options like # of cores. 
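These relaxed defaults (tf32/tf32x3 dot precisions, FP8 types allowed) are what let the conversion and dot tests above run on the CPU backend. As a minimal sketch of what they permit — the kernel name, tile size, and launch grid below are illustrative choices, not taken from this patch — a CPU kernel can now round-trip data through float8e4nv and request tf32 precision for tl.dot:

import torch
import triton
import triton.language as tl

@triton.jit
def fp8_roundtrip_dot(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]
    # Cast fp32 tiles to float8e4nv and back; these casts are what the
    # decompose-fp-conversions pass lowers to plain arith/vector ops on CPU.
    a = tl.load(a_ptr + idx).to(tl.float8e4nv).to(tl.float32)
    b = tl.load(b_ptr + idx).to(tl.float8e4nv).to(tl.float32)
    # "tf32" is now listed in allowed_dot_input_precisions for this backend.
    c = tl.dot(a, b, input_precision="tf32")
    tl.store(c_ptr + idx, c)

triton.runtime.driver.set_active_to_cpu()  # as done in the CPU tutorials
a = torch.randn(16, 16, device="cpu")
b = torch.randn(16, 16, device="cpu")
c = torch.empty(16, 16, device="cpu")
fp8_roundtrip_dot[(1,)](a, b, c, BLOCK=16)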
@@ -105,8 +107,9 @@ def make_tttcir(self, mod, metadata, opt): # We don't have any lowering for mixed precision matmuls, so always use casts for now convert_mixed_precision_matmul = True cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul) - if promote_bf16_to_fp32: - cpu.passes.ttcpuir.add_decompose_fp_conversions(pm) + decompose_bf16_conv = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features + decompose_fp8_conv = True + cpu.passes.ttcpuir.add_decompose_fp_conversions(pm, decompose_bf16_conv, decompose_fp8_conv) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index 213161fecc8a..89810b7ce526 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -23,6 +23,9 @@ std::unique_ptr> createConvertUnsupportedOps(bool promoteBf16ToFp32, bool convertMixedPrecisionMatmul); std::unique_ptr> createDecomposeFpConversions(); +std::unique_ptr> +createDecomposeFpConversions(bool decomposeBf16Conversions, + bool decomposeFp8Conversions); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index 2e92bc42c6c5..d2273873310d 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -36,7 +36,16 @@ def DecomposeFpConversions : Pass<"triton-cpu-decompose-fp-conversions", "mlir:: intrinsics. This pass transforms such conversions into vector code instead. }]; - // TODO: add options to specify which FP conversions to decompose. 
+ + let options = [ + Option<"decomposeBf16Conversions", "decompose-bf16-conversions", + "bool", /*default*/"false", + "Lower BF16 conversions to arith operations.">, + Option<"decomposeFp8Conversions", "decompose-fp8-conversions", + "bool", /*default*/"false", + "Lower FP8 conversions to arith operations.">, + ]; + let constructor = "mlir::triton::cpu::createDecomposeFpConversions()"; let dependentDialects = ["mlir::arith::ArithDialect", diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp index a82958b3ae2b..02d7087bf0a2 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -12,8 +12,10 @@ namespace mlir { namespace triton { +namespace cpu { #define GEN_PASS_DEF_DECOMPOSEFPCONVERSIONS #include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu } // namespace triton } // namespace mlir @@ -33,33 +35,479 @@ struct Fp32ToBf16Conversion : public OpRewritePattern { return failure(); Location loc = op.getLoc(); - Value i32Src = - rewriter.create(loc, toInt32(src.getType()), src); - TypedAttr shiftValAttr = rewriter.getI32IntegerAttr(16); - if (auto vecTy = dyn_cast(i32Src.getType())) - shiftValAttr = SplatElementsAttr::get(vecTy, shiftValAttr); - Value shiftedSrc = rewriter.create( - loc, i32Src, rewriter.create(loc, shiftValAttr)); - Value i16Res = rewriter.create(loc, toInt16(src.getType()), - shiftedSrc); - Value res = rewriter.create(loc, op.getType(), i16Res); + Value i32Src = op_bitcast(toInt32(src.getType()), src); + Value shiftedSrc = op_lshr(i32Src, cst_like(i32Src, 16)); + Value i16Res = op_trunci(toInt16(src.getType()), shiftedSrc); + Value res = op_bitcast(op.getType(), i16Res); + rewriter.replaceOp(op, res); + return success(); + } +}; + +typedef std::function FpToFpConvFn; + +// Convert FP8 to FP16/FP32. +Value convertFp8(Location loc, Value src, int srcExpBits, int srcExpBias, + Type dstFpTy, PatternRewriter &rewriter) { + assert(srcExpBits >= 4 && srcExpBits <= 5 && "Unexpect FP8 type conversion"); + assert(srcExpBias >= 0 && srcExpBias <= 16 && "Unexpect FP8 type conversion"); + assert((dstFpTy.isF16() || dstFpTy.isF32()) && + "Unsupported FP8 type conversion"); + Type srcTy = src.getType(); + Type dstTy = toTyOrVectorOf(srcTy, dstFpTy); + int dstExpBits = dstFpTy.isF16() ? 5 : 8; + int dstMantBits = dstFpTy.isF16() ? 10 : 23; + int dstExpBias = dstFpTy.isF16() ? 15 : 127; + int srcMantBits = 7 - srcExpBits; + assert(dstExpBias >= srcExpBias && "Unsupported FP8 type conversion"); + Type dstIntTy = + dstFpTy.isF16() ? 
rewriter.getI16Type() : rewriter.getI32Type(); + Value i8Src = op_bitcast(toInt8(srcTy), src); + Value intSrc = op_zext(toTyOrVectorOf(srcTy, dstIntTy), i8Src); + Value shiftedVal; + if (srcExpBits != dstExpBits) { + Value sign = op_and(intSrc, cst_like(intSrc, 0x80)); + Value nosign = op_and(intSrc, cst_like(intSrc, 0x7f)); + shiftedVal = op_addi( + op_shl(sign, cst_like(sign, dstFpTy.getIntOrFloatBitWidth() - 8)), + op_shl(nosign, cst_like(nosign, dstMantBits - srcMantBits))); + } else { + shiftedVal = + op_shl(intSrc, cst_like(intSrc, dstFpTy.getIntOrFloatBitWidth() - 8)); + } + Value res = op_bitcast(dstTy, shiftedVal); + if (srcExpBias != dstExpBias) { + double scale = pow(2, dstExpBias - srcExpBias); + res = op_mulf(res, cst_like(res, scale)); + } + return res; +} + +Value convertFp8E4M3ToFp16(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 4, 7, rewriter.getF16Type(), rewriter); +} + +Value convertFp8E5M2ToFp16(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 5, 15, rewriter.getF16Type(), rewriter); +} + +Value convertFp8E5M2B16ToFp16(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 5, 16, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toFp16(src.getType()), f32Res); +} + +Value convertFp8E4M3ToBf16(Location loc, Value src, PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 4, 7, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toBf16(src.getType()), f32Res); +} + +Value convertFp8E5M2ToBf16(Location loc, Value src, PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 5, 15, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toBf16(src.getType()), f32Res); +} + +Value convertFp8E5M2B16ToBf16(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Res = convertFp8(loc, src, 5, 16, rewriter.getF32Type(), rewriter); + return rewriter.create(loc, toBf16(src.getType()), f32Res); +} + +Value convertFp8E4M3ToFp32(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 4, 7, rewriter.getF32Type(), rewriter); +} + +Value convertFp8E5M2ToFp32(Location loc, Value src, PatternRewriter &rewriter) { + return convertFp8(loc, src, 5, 15, rewriter.getF32Type(), rewriter); +} + +Value convertFp8E5M2B16ToFp32(Location loc, Value src, + PatternRewriter &rewriter) { + return convertFp8(loc, src, 5, 16, rewriter.getF32Type(), rewriter); +} + +// Convert F16/FP32 to FP8. +Value convertToFp8(Location loc, Value src, Type dstFpTy, int dstExpBits, + int dstExpBias, bool rtneRounding, bool unsignedZero, + PatternRewriter &rewriter) { + assert(dstExpBits >= 4 && dstExpBits <= 5 && "Unexpect FP8 type conversion"); + assert(dstExpBias >= 0 && dstExpBias <= 16 && "Unexpect FP8 type conversion"); + Type srcTy = src.getType(); + Type srcFpTy = getElemTyOrTy(srcTy); + assert((srcFpTy.isF16() || srcFpTy.isF32()) && + "Unsupported FP8 type conversion"); + int dstMantBits = 7 - dstExpBits; + int srcExpBits = srcFpTy.isF16() ? 5 : 8; + int srcMantBits = srcFpTy.isF16() ? 10 : 23; + int srcExpBias = srcFpTy.isF16() ? 15 : 127; + assert(dstExpBias <= srcExpBias && "Unsupported FP8 type conversion"); + Type srcIntTy = + srcFpTy.isF16() ? rewriter.getI16Type() : rewriter.getI32Type(); + Value intSrc = op_bitcast(toTyOrVectorOf(srcTy, srcIntTy), src); + // Extract sign and put it to the proper place for FP8. 
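+ // The rest of this routine splits the value into exponent and mantissa,
+ // re-biases the exponent into the destination range (scaling the mantissa
+ // down when the re-biased exponent would go negative, flushing very small
+ // values to zero), optionally applies round-to-nearest-even via the
+ // add/subtract-offset trick below, and then reassembles sign, exponent and
+ // mantissa into the low 8 bits. When unsignedZero is set, a negative zero
+ // (0x80) is canonicalized to +0.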
+ Value sign = + op_lshr(op_and(intSrc, cst_like(intSrc, 1 << (srcExpBits + srcMantBits))), + cst_like(intSrc, srcFpTy.getIntOrFloatBitWidth() - 8)); + // Extract mantissa and exponent. + Value mant = op_and(intSrc, cst_like(intSrc, (1 << srcMantBits) - 1)); + Value exp = op_and(op_lshr(intSrc, cst_like(intSrc, srcMantBits)), + cst_like(intSrc, (1 << srcExpBits) - 1)); + Value isZeroExp = op_icmp_eq(exp, cst_like(exp, 0)); + mant = op_select(isZeroExp, mant, + op_addi(mant, cst_like(mant, 1 << srcMantBits))); + exp = op_select(isZeroExp, exp, op_subi(exp, cst_like(exp, 1))); + double adjustment = pow(0.5, srcMantBits - dstMantBits); + exp = op_subi(exp, cst_like(exp, srcExpBias - dstExpBias)); + mant = op_mulf(op_sitofp(srcTy, mant), cst_like(src, adjustment)); + // Make exponent non-negative. + if (dstExpBias - srcExpBias <= -8) { + // In this case we don't have enough mantissa bits, so can round to 0. + Value mask = op_icmp_sgt(exp, cst_like(exp, -8)); + exp = op_select(mask, exp, cst_like(exp, 0)); + mant = op_select(mask, mant, cst_like(mant, 0.0)); + } + if (dstExpBias - srcExpBias <= -4) { + Value mask = op_icmp_sgt(exp, cst_like(exp, -4)); + exp = op_select(mask, exp, op_addi(exp, cst_like(exp, 4))); + mant = op_select(mask, mant, op_mulf(mant, cst_like(mant, 0.0625))); + } + if (dstExpBias - srcExpBias <= -2) { + Value mask = op_icmp_sgt(exp, cst_like(exp, -2)); + exp = op_select(mask, exp, op_addi(exp, cst_like(exp, 2))); + mant = op_select(mask, mant, op_mulf(mant, cst_like(mant, 0.25))); + } + if (dstExpBias - srcExpBias <= -1) { + Value mask = op_icmp_sgt(exp, cst_like(exp, -1)); + exp = op_select(mask, exp, op_addi(exp, cst_like(exp, 1))); + mant = op_select(mask, mant, op_mulf(mant, cst_like(mant, 0.5))); + } + if (rtneRounding) { + // Bring the value to the range [2 ** 10/23, 2 ** 11/24] + // where the representable fp16/fp32 map exactly to integers. + // Addition has RTNE semantics. + Value offs = cst_like(mant, static_cast(1 << srcMantBits)); + mant = op_addf(mant, offs); + mant = op_subf(mant, offs); + } + mant = op_fptosi(toTyOrVectorOf(srcTy, srcIntTy), mant); + + Value res = + op_addi(sign, op_addi(op_shl(exp, cst_like(exp, 7 - dstExpBits)), mant)); + res = op_trunci(toInt8(srcTy), res); + if (unsignedZero) { + Value isNegativeZero = op_icmp_eq(res, cst_like(res, 0x80)); + res = op_select(isNegativeZero, cst_like(res, 0), res); + } + res = op_bitcast(toTyOrVectorOf(srcTy, dstFpTy), res); + return res; +} + +Value convertFp16ToFp8E4M3Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + // TODO: Fix type to Float8E4M3FN. + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, false, + false, rewriter); +} + +Value convertFp16ToFp8E4M3Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + // TODO: Fix type to Float8E4M3FN. 
+ return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, true, + false, rewriter); +} + +Value convertFp16ToFp8E5M2Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Type srcTy = src.getType(); + Type dstTy = toFp8E5M2(srcTy); + Value i16Src = op_bitcast(toInt16(srcTy), src); + Value shiftedSrc = op_lshr(i16Src, cst_like(i16Src, 8)); + Value i8Res = op_trunci(toInt8(srcTy), shiftedSrc); + Value res = op_bitcast(dstTy, i8Res); + return res; +} + +Value convertFp16ToFp8E5M2Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Type srcTy = src.getType(); + Type dstTy = toFp8E5M2(srcTy); + Value i16Src = op_bitcast(toInt16(srcTy), src); + Value sign = op_and(i16Src, cst_like(i16Src, 0x8000)); + Value truncated = op_and(i16Src, cst_like(i16Src, 0x7f00)); + Value tail = op_and(i16Src, cst_like(i16Src, 0xff)); + Value odd_trunc = op_icmp_ne(op_and(truncated, cst_like(truncated, 0x100)), + cst_like(truncated, 0)); + Value round_up = + op_or(op_icmp_ugt(tail, cst_like(tail, 0x80)), + op_and(op_icmp_eq(tail, cst_like(tail, 0x80)), odd_trunc)); + // Skip round-up if it leads to inf/nan. + round_up = + op_and(round_up, op_icmp_ult(truncated, cst_like(truncated, 0x7b00))); + truncated = op_select( + round_up, op_addi(truncated, cst_like(truncated, 0x100)), truncated); + + Value res = op_lshr(op_or(truncated, sign), cst_like(truncated, 8)); + res = op_bitcast(dstTy, op_trunci(toInt8(srcTy), res)); + return res; +} + +Value convertFp16ToFp8E5M2B16Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + false, true, rewriter); +} + +Value convertFp16ToFp8E5M2B16Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + true, true, rewriter); +} + +Value convertBf16ToFp8E4M3Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + // TODO: Fix type to Float8E4M3FN. + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNUZType(), 4, 7, + false, false, rewriter); +} + +Value convertBf16ToFp8E4M3Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + // TODO: Fix type to Float8E4M3FN. 
+ Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNUZType(), 4, 7, true, + false, rewriter); +} + +Value convertBf16ToFp8E5M2Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2Type(), 5, 15, false, + false, rewriter); +} + +Value convertBf16ToFp8E5M2Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2Type(), 5, 15, true, + false, rewriter); +} + +Value convertBf16ToFp8E5M2B16Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + false, true, rewriter); +} + +Value convertBf16ToFp8E5M2B16Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + Value f32Src = + rewriter.create(loc, toFp32(src.getType()), src); + return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + true, true, rewriter); +} + +Value convertFp32ToFp8E4M3Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + // TODO: Fix type to Float8E4M3FN. + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, false, + false, rewriter); +} + +Value convertFp32ToFp8E4M3Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + // TODO: Fix type to Float8E4M3FN. + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, true, + false, rewriter); +} + +Value convertFp32ToFp8E5M2Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2Type(), 5, 15, false, + false, rewriter); +} + +Value convertFp32ToFp8E5M2Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2Type(), 5, 15, true, + false, rewriter); +} + +Value convertFp32ToFp8E5M2B16Rtz(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2FNUZType(), 5, 16, false, + true, rewriter); +} + +Value convertFp32ToFp8E5M2B16Rtne(Location loc, Value src, + PatternRewriter &rewriter) { + return convertToFp8(loc, src, rewriter.getFloat8E5M2FNUZType(), 5, 16, true, + true, rewriter); +} + +FpToFpConvFn +getFpToFpConversionFn(Type srcTy, Type dstTy, + std::optional roundMode) { + // TODO: Float8E4M3FNUZType is used for both float8e4nv and float8e4b8 by + // frontend. float8e4b8 tests are skipped for CPU so we interpret this type as + // float8e4nv. Needs to be fixed. See get_fp8e4nv_ty at ir.cc. 
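+ // Conversions are dispatched through two lookup tables below, keyed by the
+ // source/destination TypeIDs: plain extensions need no rounding mode, while
+ // truncations are additionally keyed by the requested arith::RoundingMode.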
+ auto F8E4M3TyID = TypeID::get(); + auto F8E5M2TyID = TypeID::get(); + auto F8E5M2B16TyID = TypeID::get(); + auto F16TyID = TypeID::get(); + auto BF16TyID = TypeID::get(); + auto F32TyID = TypeID::get(); + + static DenseMap, FpToFpConvFn> fpExtFnMap = { + {{F8E4M3TyID, F16TyID}, convertFp8E4M3ToFp16}, + {{F8E5M2TyID, F16TyID}, convertFp8E5M2ToFp16}, + {{F8E5M2B16TyID, F16TyID}, convertFp8E5M2B16ToFp16}, + {{F8E4M3TyID, BF16TyID}, convertFp8E4M3ToBf16}, + {{F8E5M2TyID, BF16TyID}, convertFp8E5M2ToBf16}, + {{F8E5M2B16TyID, BF16TyID}, convertFp8E5M2B16ToBf16}, + {{F8E4M3TyID, F32TyID}, convertFp8E4M3ToFp32}, + {{F8E5M2TyID, F32TyID}, convertFp8E5M2ToFp32}, + {{F8E5M2B16TyID, F32TyID}, convertFp8E5M2B16ToFp32}, + }; + static DenseMap, FpToFpConvFn> + fpTruncFnMap = { + {{F16TyID, F8E4M3TyID, arith::RoundingMode::toward_zero}, + convertFp16ToFp8E4M3Rtz}, + {{F16TyID, F8E4M3TyID, arith::RoundingMode::to_nearest_even}, + convertFp16ToFp8E4M3Rtne}, + {{F16TyID, F8E5M2TyID, arith::RoundingMode::toward_zero}, + convertFp16ToFp8E5M2Rtz}, + {{F16TyID, F8E5M2TyID, arith::RoundingMode::to_nearest_even}, + convertFp16ToFp8E5M2Rtne}, + {{F16TyID, F8E5M2B16TyID, arith::RoundingMode::toward_zero}, + convertFp16ToFp8E5M2B16Rtz}, + {{F16TyID, F8E5M2B16TyID, arith::RoundingMode::to_nearest_even}, + convertFp16ToFp8E5M2B16Rtne}, + {{BF16TyID, F8E4M3TyID, arith::RoundingMode::toward_zero}, + convertBf16ToFp8E4M3Rtz}, + {{BF16TyID, F8E4M3TyID, arith::RoundingMode::to_nearest_even}, + convertBf16ToFp8E4M3Rtne}, + {{BF16TyID, F8E5M2TyID, arith::RoundingMode::toward_zero}, + convertBf16ToFp8E5M2Rtz}, + {{BF16TyID, F8E5M2TyID, arith::RoundingMode::to_nearest_even}, + convertBf16ToFp8E5M2Rtne}, + {{BF16TyID, F8E5M2B16TyID, arith::RoundingMode::toward_zero}, + convertBf16ToFp8E5M2B16Rtz}, + {{BF16TyID, F8E5M2B16TyID, arith::RoundingMode::to_nearest_even}, + convertBf16ToFp8E5M2B16Rtne}, + {{F32TyID, F8E4M3TyID, arith::RoundingMode::toward_zero}, + convertFp32ToFp8E4M3Rtz}, + {{F32TyID, F8E4M3TyID, arith::RoundingMode::to_nearest_even}, + convertFp32ToFp8E4M3Rtne}, + {{F32TyID, F8E5M2TyID, arith::RoundingMode::toward_zero}, + convertFp32ToFp8E5M2Rtz}, + {{F32TyID, F8E5M2TyID, arith::RoundingMode::to_nearest_even}, + convertFp32ToFp8E5M2Rtne}, + {{F32TyID, F8E5M2B16TyID, arith::RoundingMode::toward_zero}, + convertFp32ToFp8E5M2B16Rtz}, + {{F32TyID, F8E5M2B16TyID, arith::RoundingMode::to_nearest_even}, + convertFp32ToFp8E5M2B16Rtne}, + }; + + if (roundMode) { + auto key = + std::make_tuple(srcTy.getTypeID(), dstTy.getTypeID(), *roundMode); + if (fpTruncFnMap.count(key)) + return fpTruncFnMap.at(key); + } else { + auto key = std::make_tuple(srcTy.getTypeID(), dstTy.getTypeID()); + if (fpExtFnMap.count(key)) + return fpExtFnMap.at(key); + } + + return FpToFpConvFn(); +} + +Value convertFpToFp(Location loc, Value src, Type dstTy, + std::optional roundMode, + PatternRewriter &rewriter) { + Type srcTy = src.getType(); + Type srcElemTy = getElemTyOrTy(srcTy); + Type dstElemTy = getElemTyOrTy(dstTy); + auto fn = getFpToFpConversionFn(srcElemTy, dstElemTy, roundMode); + if (!fn) { + llvm::errs() << "Unsupported conversion from " << srcElemTy << " to " + << dstElemTy; + if (roundMode) + llvm::errs() << " with rounding mode " + << arith::stringifyRoundingMode(*roundMode); + llvm::errs() << "\n"; + llvm_unreachable(""); + } + return fn(loc, src, rewriter); +} + +struct RewriteTruncFp8 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::TruncFOp op, + PatternRewriter 
&rewriter) const override { + Location loc = op.getLoc(); + Value src = op.getIn(); + Type srcTy = src.getType(); + Type dstTy = op.getType(); + if (!isFp8(dstTy)) + return failure(); + Value res = convertFpToFp(loc, src, dstTy, op.getRoundingmode(), rewriter); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct RewriteExtFp8 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value src = op.getIn(); + Type srcTy = src.getType(); + if (!isFp8(srcTy)) + return failure(); + Type dstTy = op.getType(); + Value res = convertFpToFp(loc, src, dstTy, std::nullopt, rewriter); rewriter.replaceOp(op, res); return success(); } }; struct DecomposeFpConversions - : public triton::impl::DecomposeFpConversionsBase { - using DecomposeFpConversionsBase::DecomposeFpConversionsBase; + : public triton::cpu::impl::DecomposeFpConversionsBase< + DecomposeFpConversions> { + DecomposeFpConversions() = default; - DecomposeFpConversions() : DecomposeFpConversionsBase() {} + DecomposeFpConversions(bool decomposeBf16Conversions, + bool decomposeFp8Conversions) { + this->decomposeBf16Conversions = decomposeBf16Conversions; + this->decomposeFp8Conversions = decomposeFp8Conversions; + } void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); RewritePatternSet patterns(context); - patterns.add(context); + if (decomposeBf16Conversions) + patterns.add(context); + if (decomposeFp8Conversions) { + patterns.add(context); + patterns.add(context); + } if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) return signalPassFailure(); @@ -76,6 +524,13 @@ std::unique_ptr> createDecomposeFpConversions() { return std::make_unique(); } +std::unique_ptr> +createDecomposeFpConversions(bool decomposeBf16Conversions, + bool decomposeFp8Conversions) { + return std::make_unique(decomposeBf16Conversions, + decomposeFp8Conversions); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h b/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h index a9fe054b8ede..ebf4ae723248 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h +++ b/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h @@ -1,46 +1,150 @@ #ifndef TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H #define TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" namespace mlir { namespace triton { namespace cpu { -inline bool isTyOrVectorOf(mlir::Type ty, mlir::Type elemTy) { - if (auto vecTy = dyn_cast(ty)) - return vecTy.getElementType() == elemTy; - return ty == elemTy; +inline Type getElemTyOrTy(Type ty) { + if (auto vecTy = dyn_cast(ty)) + return vecTy.getElementType(); + return ty; } -inline bool isBf16(mlir::Type ty) { - return isTyOrVectorOf(ty, mlir::BFloat16Type::get(ty.getContext())); +inline bool isTyOrVectorOf(Type ty, Type elemTy) { + return getElemTyOrTy(ty) == elemTy; } -inline bool isFp32(mlir::Type ty) { - return isTyOrVectorOf(ty, mlir::Float32Type::get(ty.getContext())); +inline bool isBf16(Type ty) { + return isTyOrVectorOf(ty, BFloat16Type::get(ty.getContext())); } -inline mlir::Type toTyOrVectorOf(mlir::Type ty, mlir::Type elemTy) { - if (auto vecTy = dyn_cast(ty)) +inline bool isFp16(Type ty) { + return isTyOrVectorOf(ty, 
Float16Type::get(ty.getContext())); +} + +inline bool isFp32(Type ty) { + return isTyOrVectorOf(ty, Float32Type::get(ty.getContext())); +} + +inline bool isFp8(Type ty) { + Type elemTy = getElemTyOrTy(ty); + if (elemTy.isIntOrFloat() && !elemTy.isInteger()) + return elemTy.getIntOrFloatBitWidth() == 8; + return false; +} + +inline Type toTyOrVectorOf(Type ty, Type elemTy) { + if (auto vecTy = dyn_cast(ty)) return vecTy.cloneWith(std::nullopt, elemTy); return elemTy; } -inline mlir::Type toInt16(mlir::Type ty) { - return toTyOrVectorOf(ty, mlir::IntegerType::get(ty.getContext(), 16)); +inline Type toInt8(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 8)); } -inline mlir::Type toInt32(mlir::Type ty) { - return toTyOrVectorOf(ty, mlir::IntegerType::get(ty.getContext(), 32)); +inline Type toInt16(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 16)); } -inline mlir::Type toFp32(mlir::Type ty) { - return toTyOrVectorOf(ty, mlir::Float32Type::get(ty.getContext())); +inline Type toInt32(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 32)); +} + +inline Type toFp8E5M2(Type ty) { + return toTyOrVectorOf(ty, Float8E5M2Type::get(ty.getContext())); +} + +inline Type toFp16(Type ty) { + return toTyOrVectorOf(ty, Float16Type::get(ty.getContext())); +} + +inline Type toBf16(Type ty) { + return toTyOrVectorOf(ty, BFloat16Type::get(ty.getContext())); +} + +inline Type toFp32(Type ty) { + return toTyOrVectorOf(ty, Float32Type::get(ty.getContext())); +} + +inline Value intCst(Location loc, Type ty, int64_t val, + PatternRewriter &rewriter) { + TypedAttr valAttr = IntegerAttr::get(getElemTyOrTy(ty), val); + if (auto vecTy = dyn_cast(ty)) + valAttr = SplatElementsAttr::get(vecTy, valAttr); + return rewriter.create(loc, valAttr); +} + +inline Value fpCst(Location loc, Type ty, double val, + PatternRewriter &rewriter) { + TypedAttr valAttr = FloatAttr::get(getElemTyOrTy(ty), val); + if (auto vecTy = dyn_cast(ty)) + valAttr = SplatElementsAttr::get(vecTy, valAttr); + return rewriter.create(loc, valAttr); +} + +template ::value, bool> = true> +Value cstLike(Location loc, Value tySrc, T val, PatternRewriter &rewriter) { + return intCst(loc, tySrc.getType(), val, rewriter); +} + +template ::value, bool> = true> +Value cstLike(Location loc, Value tySrc, T val, PatternRewriter &rewriter) { + return fpCst(loc, tySrc.getType(), val, rewriter); } } // namespace cpu } // namespace triton } // namespace mlir +#define int_cst(ty, val) intCst(loc, ty, val, rewriter) +#define cst_like(src, val) cstLike(loc, src, val, rewriter) + +#define op_addi(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_addf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_subi(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_subf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_mulf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_bitcast(ty, val) rewriter.create(loc, ty, val) +#define op_lshr(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_shl(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_trunci(ty, val) rewriter.create(loc, ty, val) +#define op_zext(ty, val) rewriter.create(loc, ty, val) +#define op_sext(ty, val) rewriter.create(loc, ty, val) +#define op_and(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_or(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_minui(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_maxui(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_select(cond, val, other) \ + rewriter.create(loc, cond, val, 
other) +#define op_sitofp(ty, val) rewriter.create(loc, ty, val) +#define op_fptosi(ty, val) rewriter.create(loc, ty, val) + +#define op_icmp_eq(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::eq, lhs, rhs) +#define op_icmp_ne(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ne, lhs, rhs) +#define op_icmp_ugt(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ugt, lhs, rhs) +#define op_icmp_uge(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::uge, lhs, rhs) +#define op_icmp_ult(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ult, lhs, rhs) +#define op_icmp_ule(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::ule, lhs, rhs) +#define op_icmp_sgt(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::sgt, lhs, rhs) +#define op_icmp_sge(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::sge, lhs, rhs) +#define op_icmp_slt(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::slt, lhs, rhs) +#define op_icmp_sle(lhs, rhs) \ + rewriter.create(loc, arith::CmpIPredicate::sle, lhs, rhs) + #endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 783fab131862..5a66813b0a16 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -55,6 +55,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -134,6 +135,44 @@ struct ClampFOpConversion : public OpConversionPattern { } }; +struct FpToFpOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = src.getType(); + auto resTy = getTypeConverter()->convertType(op.getType()); + auto srcElemTy = isa(srcTy) + ? cast(srcTy).getElementType() + : srcTy; + auto resElemTy = isa(resTy) + ? cast(resTy).getElementType() + : resTy; + + if (srcElemTy.getIntOrFloatBitWidth() > resElemTy.getIntOrFloatBitWidth()) { + std::optional rounding = op.getRounding(); + assert(rounding && "Rounding mode expected for truncate conversions"); + auto roundingAttr = arith::RoundingModeAttr::get( + getContext(), *rounding == RoundingMode::RTZ + ? 
arith::RoundingMode::toward_zero + : arith::RoundingMode::to_nearest_even); + rewriter.replaceOpWithNewOp(op, resTy, src, roundingAttr, + nullptr); + return success(); + } + + if (srcElemTy.getIntOrFloatBitWidth() < resElemTy.getIntOrFloatBitWidth()) { + rewriter.replaceOpWithNewOp(op, resTy, src); + return success(); + } + + return failure(); + } +}; + struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -223,6 +262,7 @@ struct ConvertElementwiseOps typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 5977b6f36d17..11fd17a7c9fe 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -41,9 +41,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps( promote_bf16_to_fp32, convert_mixed_precision_matmul)); }); - m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createDecomposeFpConversions()); - }); + m.def("add_decompose_fp_conversions", + [](mlir::PassManager &pm, bool decomposeBf16Conversions, + bool decomposeFp8Conversions) { + pm.addPass(mlir::triton::cpu::createDecomposeFpConversions( + decomposeBf16Conversions, decomposeFp8Conversions)); + }); m.def("add_vector_to_scf", [](mlir::PassManager &pm, bool full_unroll, unsigned target_rank, bool lower_tensors) { mlir::VectorTransferToSCFOptions opts; From f578a9756af10ec698224672f24c4b29e66ad2e6 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Thu, 18 Jul 2024 11:29:42 -0700 Subject: [PATCH 047/165] [CPU] Support device_print for scalar types first (#54) --- third_party/cpu/backend/compiler.py | 2 + .../cpu/include/TritonCPUToLLVM/Passes.h | 1 + .../cpu/include/TritonCPUToLLVM/Passes.td | 9 + .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 2 + .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 199 ++++++++++++++++++ .../TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp | 24 +-- .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp | 2 +- .../cpu/lib/TritonCPUToLLVM/Utility.cpp | 32 +++ third_party/cpu/lib/TritonCPUToLLVM/Utility.h | 14 ++ third_party/cpu/triton_cpu.cc | 7 +- 10 files changed, 271 insertions(+), 21 deletions(-) create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/Utility.cpp create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/Utility.h diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index ccf0ace48770..6e8a3b6c329c 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -154,6 +154,8 @@ def make_llir(src, metadata, options): llvm.init_targets() context = llvm.context() llvm_mod = llvm.to_module(mod, context) + if llvm_mod is None: + raise RuntimeError("Failed to convert to LLVM IR") llvm.set_host_target(llvm_mod) #if options.extern_libs: # paths = [path for (name, path) in options.extern_libs] diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index 7d739f1c32fe..288eed5256b4 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -25,6 +25,7 @@ std::unique_ptr> createMemoryOpToLLVMPass(); 
std::unique_ptr> createGetProgramIdOpToLLVMPass(); std::unique_ptr> createLowerMultiReductionPass(); std::unique_ptr> createAtomicOpsToLLVMPass(); +std::unique_ptr> createDebugOpsToLLVMPass(); void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); void registerTritonCPUToLLVMPipeline(); diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index d8b010f35660..06a9114d7696 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -57,4 +57,13 @@ def AtomicOpsToLLVM : Pass<"triton-cpu-atomic-ops-to-llvm", "mlir::ModuleOp"> { "mlir::triton::TritonDialect"]; } +def DebugOpsToLLVM : Pass<"triton-cpu-debug-ops-to-llvm", "mlir::ModuleOp"> { + let summary = "Convert Triton debug operations (prints and asserts) to LLVM."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createDebugOpsToLLVMPass()"; + + let dependentDialects = ["mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index 9e5f71f8d4e5..d469b7968682 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -1,11 +1,13 @@ add_triton_library(TritonCPUToLLVM AtomicOpsToLLVM.cpp + DebugOpsToLLVM.cpp FuncOpToLLVM.cpp GetProgramIdOpToLLVM.cpp LowerMultiReduction.cpp MemoryOpToLLVM.cpp Pipeline.cpp TypeConverter.cpp + Utility.cpp DEPENDS TritonCPUToLLVMConversionPassIncGen diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp new file mode 100644 index 000000000000..e6b6a531059c --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -0,0 +1,199 @@ +#include "TypeConverter.h" +#include "Utility.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_DEBUGOPSTOLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +// The code for the print is similar to the GPU's TargetInfo.cpp. +LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName("printf"); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *context = rewriter.getContext(); + + // int printf(char* format, ...) 
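+ // The variadic declaration is added to the module the first time a print is
+ // lowered and reused afterwards (see the lookupSymbol check above).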
+ SmallVector argsType{ptr_ty(context)}; + auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + auto op = rewriter.create(UnknownLoc::get(context), + funcName, funcType); + return op; +} + +void emitPrintf(ConversionPatternRewriter &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args) { + auto loc = UnknownLoc::get(rewriter.getContext()); + SmallVector formatStrAndArgs{formatStrStart}; + for (auto arg : args) { + formatStrAndArgs.push_back(arg); + } + call(getPrintfDeclaration(rewriter), formatStrAndArgs); +} + +Value llPrintf(StringRef msg, ValueRange args, + ConversionPatternRewriter &rewriter, + int *formatStrByteCount = nullptr) { + assert(!msg.empty() && "printf with empty string not supported"); + llvm::SmallString<64> msgNewline(msg); + msgNewline.push_back('\n'); + msgNewline.push_back('\0'); + Value msgValue = + LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, + "printfFormat_", msgNewline); + emitPrintf(rewriter, msgValue, msgNewline.size_in_bytes(), args); + if (formatStrByteCount) + *formatStrByteCount = msgNewline.size_in_bytes(); + return msgValue; +} + +// TODO: This code is the same as the GPU-backend code. Consider refactoring. +std::string getFormatSubstr(Value value, bool hex = false, + std::optional width = std::nullopt) { + Type type = value.getType(); + if (isa(type)) { + return "%p"; + } + // Hex is "0x%0nx" or "0x%0nllx", where n is the number of hex digits in the + // type (so 4 for fp16, 8 for int32, 16 for int64). + if (hex) { + // Ignore `width` for `hex` values, pad to typeWidth. + std::string ret = "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4); + if (type.getIntOrFloatBitWidth() > 32) { + ret += "ll"; + } + ret += "x"; + return ret; + } + + std::string prefix = "%"; + if (width.has_value()) { + prefix += std::to_string(*width); + } else if (hex) { + prefix += "0"; + prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); + } + + if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { + return prefix + "f"; + } else if (type.isSignedInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "lli"; + else + return prefix + "i"; + } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { + if (type.getIntOrFloatBitWidth() == 64) + return prefix + "llu"; + else + return prefix + "u"; + } + assert(false && "not supported type"); + return ""; +} + +// TritonCPU's device_print prints all values in the same line unlike GPUs +// and interpreter where each value is printed in a separate line. 
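+// For example, tl.device_print("x: ", x) from program (0, 1, 0) is expected
+// to produce a single line of roughly the form "(0, 1, 0)x: 42": the format
+// string always starts with the three program ids, followed by the prefix and
+// then the scalar operands separated by commas.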
+struct PrintOpConversion : public ConvertOpToLLVMPattern { + explicit PrintOpConversion(LLVMTypeConverter &typeConverter) + : mlir::ConvertOpToLLVMPattern(typeConverter) {} + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + auto getPid = [&](int axis) { + return getProgramId(op->getParentOfType(), axis); + }; + SmallVector values = {getPid(0), getPid(1), getPid(2)}; + + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "(" << getFormatSubstr(values[0]) << ", " + << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) + << ")" << op.getPrefix(); + + for (size_t i = 0; i < op.getNumOperands(); i++) { + auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); + if (dyn_cast(op.getOperand(i).getType())) { + llvm_unreachable("Not implemented for tensor types"); + } + + // Only support scalars for now. + assert(elems.size() == 1); + if (i != 0) { + os << ", "; + } + os << getFormatSubstr(elems[0], op.getHex()); + values.push_back(elems[0]); + } + + llPrintf(formatStr, values, rewriter); + rewriter.eraseOp(op); + return success(); + } +}; + +struct DebugOpsToLLVM + : public triton::impl::DebugOpsToLLVMBase { + using DebugOpsToLLVMBase::DebugOpsToLLVMBase; + + DebugOpsToLLVM() : DebugOpsToLLVMBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget convTarget(*context); + + RewritePatternSet patterns(context); + patterns.add(typeConverter); + // patterns.add(typeConverter); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +namespace mlir::triton::cpu { + +std::unique_ptr> createDebugOpsToLLVMPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp index cdf45de6adc3..406b32cc7774 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/GetProgramIdOpToLLVM.cpp @@ -1,4 +1,5 @@ #include "TypeConverter.h" +#include "Utility.h" #include "cpu/include/TritonCPUToLLVM/Passes.h" @@ -43,39 +44,30 @@ class TritonLLVMConversionTarget : public ConversionTarget { }; // TODO: use enums to access struct fields. -struct GetProgramIdOpConversion : public OpConversionPattern { +struct GetProgramIdOpConversion + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(GetProgramIdOp op, OpAdaptor adaptor, + matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto funcOp = op->getParentOfType(); assert(funcOp && "expected LLVM::FuncOp as a parent of GetProgramIdOp"); - auto args = funcOp.getArguments(); - // First three of last six args are x, y, z program ids. 
- auto argIdx = args.size() - 6 + op.getAxisAsInt(); - assert(argIdx < args.size() && "out-of-bounds arg index"); - assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); - rewriter.replaceOp(op, args[argIdx]); + rewriter.replaceOp(op, getProgramId(funcOp, op.getAxisAsInt())); return success(); } }; struct GetNumProgramsOpConversion - : public OpConversionPattern { + : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(GetNumProgramsOp op, OpAdaptor adaptor, + matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto funcOp = op->getParentOfType(); assert(funcOp && "expected LLVM::FuncOp as a parent of GetNumProgramsOp"); - auto args = funcOp.getArguments(); - // Last three of args are gridX, gridY, gridZ (bounds) of grid. - auto argIdx = args.size() - 3 + op.getAxisAsInt(); - assert(argIdx < args.size() && "out-of-bounds arg index"); - assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); - rewriter.replaceOp(op, args[argIdx]); + rewriter.replaceOp(op, getNumPrograms(funcOp, op.getAxisAsInt())); return success(); } }; diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp index 0263a1e65214..8c02cc944b75 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp @@ -12,7 +12,7 @@ void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass()); - // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); + pm.addPass(mlir::triton::cpu::createDebugOpsToLLVMPass()); } void registerTritonCPUToLLVMPipeline() { diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Utility.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Utility.cpp new file mode 100644 index 000000000000..e783497bd951 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/Utility.cpp @@ -0,0 +1,32 @@ +#include "Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace mlir::triton::cpu { + +Value getProgramId(mlir::FunctionOpInterface funcOp, int axis) { + auto args = funcOp.getArguments(); + assert(funcOp && args.size() >= 6); + assert(axis >= 0 && axis < 3); + + // The first three of the last six args are x, y, z program ids. + auto argIdx = args.size() - 6 + axis; + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + return args[argIdx]; +} + +Value getNumPrograms(mlir::FunctionOpInterface funcOp, int axis) { + auto args = funcOp.getArguments(); + assert(funcOp && args.size() >= 6); + assert(axis >= 0 && axis < 3); + + // The last three of the args are gridX, gridY, gridZ (bounds) of grid. 
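+ // Together with getProgramId above, this relies on the CPU launcher passing
+ // six trailing i32 arguments to every kernel: the x/y/z program ids followed
+ // by the x/y/z grid sizes.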
+ auto argIdx = args.size() - 3 + axis; + assert(argIdx < args.size() && "out-of-bounds arg index"); + assert(args[argIdx].getType().isInteger(32) && "unexpected arg type"); + return args[argIdx]; +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Utility.h b/third_party/cpu/lib/TritonCPUToLLVM/Utility.h new file mode 100644 index 000000000000..53ffcc6651ff --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/Utility.h @@ -0,0 +1,14 @@ +#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H +#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +namespace mlir::triton::cpu { + +Value getProgramId(mlir::FunctionOpInterface funcOp, int axis); +Value getNumPrograms(mlir::FunctionOpInterface funcOp, int axis); + +} // namespace mlir::triton::cpu + +#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 11fd17a7c9fe..1ce9419e8f60 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -26,9 +26,6 @@ namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; - // m.def("add_to_llvmir", [](mlir::PassManager &pm) { - // pm.addPass(mlir::triton::createConvertTritonCPUToLLVMPass()); - // }); m.def("add_triton_to_triton_cpu_pipeline", [](mlir::PassManager &pm) { mlir::triton::cpu::tritonToTritonCPUPipelineBuilder(pm); }); @@ -99,7 +96,9 @@ void init_triton_cpu(py::module &&m) { m.def("find_kernel_names", [](mlir::ModuleOp &mod) { std::vector res; mod.walk([&](mlir::FunctionOpInterface funcOp) { - if (funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) + // Kernel functions are public and have a body. + if (!funcOp.getFunctionBody().empty() && + funcOp.getVisibility() == mlir::SymbolTable::Visibility::Public) res.push_back(funcOp.getName().str()); }); return res; From e5763464fa6105f758e025300c426da70cf19fab Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Fri, 19 Jul 2024 13:24:18 -0700 Subject: [PATCH 048/165] [TUTORIAL] Add matrix vector multiplication tutorial (#46) * Add matrix vector multiplication tutorial. * Fix: resolve review comment * Add test for: torch.matmul (reshape x to 2D), torch.matmul (transpose weight and flip the order), torch.nn.Linear * Change duplicated line_styles --- .../tutorials/matrix-vector-multiplication.py | 210 ++++++++++++++++++ 1 file changed, 210 insertions(+) create mode 100644 python/tutorials/matrix-vector-multiplication.py diff --git a/python/tutorials/matrix-vector-multiplication.py b/python/tutorials/matrix-vector-multiplication.py new file mode 100644 index 000000000000..288daab90178 --- /dev/null +++ b/python/tutorials/matrix-vector-multiplication.py @@ -0,0 +1,210 @@ +import torch + +import triton +import triton.language as tl + +BLOCK_SIZE_M = 1 +BLOCK_SIZE_N = 512 +USE_GPU = False +""" +Kernel for computing Y = A @ X, where A is a dense matrix with +M rows and N columns. 
+- Input X has shape (N,) +- A has shape (M, N) +- Output has shape (M,) +""" + + +@triton.jit +def gemv_kernel( + Y, + A, + X, + M, + N, + stride_am, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + start_m = tl.program_id(0) + rm = start_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = tl.arange(0, BLOCK_SIZE_N) + + A = A + (rm[:, None] * stride_am + rn[None, :]) + X = X + rn + + acc = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + for n in range(N, 0, -BLOCK_SIZE_N): + a = tl.load(A) + x = tl.load(X) + acc += tl.sum(a * x[None, :], axis=1) + A += BLOCK_SIZE_N + X += BLOCK_SIZE_N + + Y = Y + rm + tl.store(Y, acc) + + +def gemv( + weight: torch.Tensor, + x: torch.Tensor, + output: torch.Tensor, +): + assert weight.shape[1] == x.shape[0], "Incompatible dimensions" + assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous" + assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + + M, N = weight.shape + + # TODO: Currently masked load is not supported yet. + assert M % BLOCK_SIZE_M == 0 and N % BLOCK_SIZE_N == 0, "Masking currently not supported, Matrix dimensions must be multiples of block size" + + if output is None: + # Allocates output. + output = torch.empty((M, ), device=x.device, dtype=x.dtype) + else: + assert output.shape == (M, ) and output.dtype == x.dtype, "Incompatible output" + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), ) + + gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N) + + return output + + +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + +weight = torch.randn((512, 1024), device='cpu', dtype=torch.float32) +x = torch.randn((1024), device='cpu', dtype=torch.float32) +triton_output = gemv(weight, x, None) +# torch.matmul will select bf16 kernels on Linux Arm if x is 1-d, which has lower precision. +# So we reshape x to be 2-d, which will invoke different kernels. 
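+# x[:, None] turns x into an (N, 1) column, so the product has shape (M, 1);
+# reshape(-1) flattens it back to (M,) to match the Triton output.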
+torch_output = torch.matmul(weight, x[:, None]).reshape(-1) +#print(f"triton_cpu_output_with_{weight.dtype}_inputs={triton_output}") +#print(f"torch_cpu_output_with_{weight.dtype}_inputs={torch_output}") +rtol = 0 +if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + +LINE_VALS = [ + 'triton-cpu-single', 'triton-cpu', 'triton-cpu-linear', 'torch-cpu-native', 'torch-cpu-compile', + 'torch-cpu-2d-native', 'torch-cpu-2d-compile', 'torch-cpu-transpose-native', 'torch-cpu-transpose-compile', + 'torch-cpu-linear' +] +LINE_NAMES = [ + 'TritonCPU 1', 'TritonCPU', 'TritonCPU Linear', 'TorchCPU (native)', 'TorchCPU (compile)', 'TorchCPU 2D (native)', + 'TorchCPU 2D (compile)', 'TorchCPU Transpose (native)', 'TorchCPU Transpose (compile)', 'TorchCPU Linear' +] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('blue', ':'), ('green', '--'), ('green', '-'), ('red', '--'), + ('red', '-'), ('yellow', '--'), ('yellow', '-'), ('purple', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + weight = weight.to('cuda') + x = x.to('cuda') + triton_output = gemv(weight, x, None) + torch_output = torch.matmul(weight, x) + #print(f"triton_gpu_output_with_{a.dtype}_inputs={triton_output}") + #print(f"torch_gpu_output_with_{a.dtype}_inputs={torch_output}") + rtol = 0 + if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonGPU and TorchGPU match") + else: + print("❌ TritonGPU and TorchGPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('pink', '-'), ('cyan', '-')] + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["N"], # Argument names to use as an x-axis for the plot + x_vals=[512 * i for i in range(10, 21)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'gemv-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N})', + args={'M': 4096}, # Values for function arguments not in `x_names` and `y_name`. 
+ )) +def benchmark(M, N, provider): + import os + + device = 'cpu' if 'cpu' in provider else 'cuda' + weight = torch.randn((M, N), device=device, dtype=torch.float32) + x = torch.randn((N), device=device, dtype=torch.float32) + + if device == 'cpu': + output = torch.empty((M), device=x.device, dtype=x.dtype) + triton.runtime.driver.set_active_to_cpu() + if 'single' in provider: + os.environ['TRITON_CPU_SINGLE_CORE'] = '1' + else: + os.unsetenv('TRITON_CPU_SINGLE_CORE') + + if 'transpose' in provider: + weight = torch.transpose(weight, 0, 1) + x = x[None, :] + output = output[None, :] + elif '2d' in provider: + x = x[:, None] + output = output[:, None] + else: + output = None + triton.runtime.driver.set_active_to_gpu() + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) + elif provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + elif provider == 'torch-cpu-native' or provider == 'torch-cpu-2d-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles, + is_cpu=True) + elif provider == 'torch-cpu-compile' or provider == 'torch-cpu-2d-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(weight, x, out=output), quantiles=quantiles, + is_cpu=True) + elif provider == 'torch-cpu-transpose-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, weight, out=output), quantiles=quantiles, + is_cpu=True) + elif provider == 'torch-cpu-transpose-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x, weight, out=output), quantiles=quantiles, + is_cpu=True) + elif provider == 'torch-cpu-linear': + weight = torch.nn.Linear(N, M, bias=False, device=weight.device, dtype=weight.dtype) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles, is_cpu=True) + elif provider == 'triton-cpu-single': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, is_cpu=True) + elif provider == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, is_cpu=True) + elif provider == 'triton-cpu-linear': + # torch.nn.Linear.forward does not take preallocated output buffer, so we also do no provide output buffer for fair comparison + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles, is_cpu=True) + perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) From d8452725ed580e32fb6c14b473d23c4d80f6d9a9 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 19 Jul 2024 17:32:19 -0500 Subject: [PATCH 049/165] Fix FuncOp lowering. 
(#61) --- python/test/unit/language/test_core.py | 7 +++++++ .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 20 ++++++++++--------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index a07691c329ea..8ec022110470 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -319,6 +319,13 @@ def kernel(X, SIZE: tl.constexpr): x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) kernel[(1, )](x, SIZE=SIZE, num_warps=4) +@pytest.mark.cpu +def test_empty_kernel_scalar_arg(device): + @triton.jit + def kernel(x): + pass + + kernel[(1, )](2) def test_scalar_overflow(device): diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp index 0d6db8e13154..0f0193da57cf 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -88,16 +88,18 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern { SmallVector amendedAttrs; filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); SmallVector amendedArgAttrs; - if (funcOp.getAllArgAttrs()) + if (funcOp.getAllArgAttrs()) { amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); - amendedAttrs.push_back(rewriter.getNamedAttr( - funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back( + rewriter.getNamedAttr(funcOp.getArgAttrsAttrName(), + rewriter.getArrayAttr(amendedArgAttrs))); + } // 3. 
Add a new arguments to the region auto amendedFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); From a5874034cd8d812df3e7ce0a632fd3ea8c6d98f4 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Fri, 19 Jul 2024 15:34:33 -0700 Subject: [PATCH 050/165] [CPU] Easy: remove the old initial boilerplate code (#59) --- include/triton/Conversion/CMakeLists.txt | 3 -- .../Conversion/TritonCPUToLLVM/CMakeLists.txt | 3 -- .../TritonCPUToLLVM/CPUTargetInfo.h | 22 ---------- .../Conversion/TritonCPUToLLVM/Passes.h | 29 ------------ .../Conversion/TritonCPUToLLVM/Passes.td | 25 ----------- .../PatternTritonCPUOpToLLVM.h | 44 ------------------- .../TritonCPUToLLVM/TypeConverter.h | 22 ---------- .../Conversion/TritonCPUToLLVM/Utility.h | 21 --------- .../TritonToTritonCPU/CMakeLists.txt | 3 -- .../Conversion/TritonToTritonCPU/Passes.h | 15 ------- .../Conversion/TritonToTritonCPU/Passes.td | 23 ---------- .../TritonToTritonCPU/TritonToTritonCPUPass.h | 18 -------- third_party/cpu/triton_cpu.cc | 2 - 13 files changed, 230 deletions(-) delete mode 100644 include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt delete mode 100644 include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h delete mode 100644 include/triton/Conversion/TritonCPUToLLVM/Passes.h delete mode 100644 include/triton/Conversion/TritonCPUToLLVM/Passes.td delete mode 100644 include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h delete mode 100644 include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h delete mode 100644 include/triton/Conversion/TritonCPUToLLVM/Utility.h delete mode 100644 include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt delete mode 100644 include/triton/Conversion/TritonToTritonCPU/Passes.h delete mode 100644 include/triton/Conversion/TritonToTritonCPU/Passes.td delete mode 100644 include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h diff --git a/include/triton/Conversion/CMakeLists.txt b/include/triton/Conversion/CMakeLists.txt index 3b8a95e1ecf7..730f5cadd246 100644 --- a/include/triton/Conversion/CMakeLists.txt +++ b/include/triton/Conversion/CMakeLists.txt @@ -1,5 +1,2 @@ -# TODO(minjang): I will remove these scratches soon. -# add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonGPUToLLVM) -# add_subdirectory(TritonToTritonCPU) add_subdirectory(TritonToTritonGPU) diff --git a/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt b/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt deleted file mode 100644 index 0936dff12d91..000000000000 --- a/include/triton/Conversion/TritonCPUToLLVM/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonCPUToLLVM) -add_public_tablegen_target(TritonCPUToLLVMConversionPassIncGen) diff --git a/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h b/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h deleted file mode 100644 index 66f6b57b1c57..000000000000 --- a/include/triton/Conversion/TritonCPUToLLVM/CPUTargetInfo.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H -#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H - -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Conversion/MLIRTypes.h" - -namespace mlir::triton::cpu { -class CPUTargetInfo { -public: - // Note: we may revisit for different CPU ISAs like AVX and Neon. 
- CPUTargetInfo() {} - - Value programId(ConversionPatternRewriter &rewriter, Location loc, - LLVM::LLVMFuncOp funcOp, int axis) const; - - void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, - int formatStrByteCount, ValueRange args) const; - - ~CPUTargetInfo() {} -}; -} // namespace mlir::triton::cpu -#endif // TRITON_CONVERSION_TRITONCPU_TO_LLVM_TARGETINFOBASE_H diff --git a/include/triton/Conversion/TritonCPUToLLVM/Passes.h b/include/triton/Conversion/TritonCPUToLLVM/Passes.h deleted file mode 100644 index f06efc13d004..000000000000 --- a/include/triton/Conversion/TritonCPUToLLVM/Passes.h +++ /dev/null @@ -1,29 +0,0 @@ -#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H -#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_PASSES_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - -#include - -namespace mlir { - -class ModuleOp; -template class OperationPass; - -namespace triton { - -#define GEN_PASS_DECL -#include "triton/Conversion/TritonCPUToLLVM/Passes.h.inc" - -std::unique_ptr> createConvertTritonCPUToLLVMPass(); - -#define GEN_PASS_REGISTRATION -#include "triton/Conversion/TritonCPUToLLVM/Passes.h.inc" - -} // namespace triton - -} // namespace mlir - -#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/Passes.td b/include/triton/Conversion/TritonCPUToLLVM/Passes.td deleted file mode 100644 index a0bfd65c3d28..000000000000 --- a/include/triton/Conversion/TritonCPUToLLVM/Passes.td +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef TRITONCPU_CONVERSION_PASSES -#define TRITONCPU_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - -def ConvertTritonCPUToLLVM : Pass<"convert-triton-cpu-to-llvm", "mlir::ModuleOp"> { - let summary = "Convert TritonCPU to LLVM"; - let description = [{ - - }]; - let constructor = "mlir::triton::createConvertTritonCPUToLLVMPass()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::LLVM::LLVMDialect", - "mlir::math::MathDialect", - "mlir::scf::SCFDialect", - "mlir::tensor::TensorDialect", - "mlir::triton::cpu::TritonCPUDialect", - "mlir::triton::TritonDialect"]; - - let options = [ - ]; -} - -#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h b/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h deleted file mode 100644 index f5cd3612dac5..000000000000 --- a/include/triton/Conversion/TritonCPUToLLVM/PatternTritonCPUOpToLLVM.h +++ /dev/null @@ -1,44 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H -#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_PATTERNS_TRITON_CPU_OP_TO_LLVM_H - -#include "CPUTargetInfo.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" - -using namespace mlir; -using namespace mlir::triton; - -namespace mlir { -namespace triton { -// Some populate* functions have name collisions with the ones for GPUs. 
-namespace cpu { - -constexpr int patternBenefitDefault = 1; -constexpr int patternBenefitPrioritizeOverLLVMConversions = 10; -constexpr int patternBenefitClampOptimizedPattern = 20; -constexpr int patternBenefitConvertLayoutOptimizedPattern = 20; - -void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - const cpu::CPUTargetInfo &targetInfo, - PatternBenefit benefit); - -void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - PatternBenefit benefit); - -void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - PatternBenefit benefit); - -void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, - const CPUTargetInfo &targetInfo, - PatternBenefit benefit); - -} // namespace cpu -} // namespace triton -} // namespace mlir - -#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h b/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h deleted file mode 100644 index 8ed9e6d4d849..000000000000 --- a/include/triton/Conversion/TritonCPUToLLVM/TypeConverter.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef TRITONCPU_CONVERSION_TRITONCPUTOLLVM_TYPECONVERTER_H -#define TRITONCPU_CONVERSION_TRITONCPUTOLLVM_TYPECONVERTER_H - -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/TritonCPU/IR/Types.h" - -using namespace mlir; -using namespace mlir::triton; - -class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { -public: - using TypeConverter::convertType; - - TritonCPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option, - const DataLayoutAnalysis *analysis = nullptr); - - Type convertTritonPointerType(triton::PointerType type); -}; - -#endif diff --git a/include/triton/Conversion/TritonCPUToLLVM/Utility.h b/include/triton/Conversion/TritonCPUToLLVM/Utility.h deleted file mode 100644 index 8562271340a1..000000000000 --- a/include/triton/Conversion/TritonCPUToLLVM/Utility.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H -#define TRITON_CONVERSION_TRITONCPU_TO_LLVM_UTILITY_H - -#include "mlir/Conversion/LLVMCommon/Pattern.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Analysis/Utility.h" -#include "triton/Conversion/MLIRTypes.h" -#include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" -#include "llvm/Support/ErrorHandling.h" - -using namespace mlir; -using namespace mlir::triton; - -// TODO: Do better refactoring. 
-#include "triton/Conversion/TritonGPUToLLVM/Utility.h" - -#undef DEBUG_TYPE -#define DEBUG_TYPE "ttcpu_to_llvm" - -#endif diff --git a/include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt b/include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt deleted file mode 100644 index 66945e2242f1..000000000000 --- a/include/triton/Conversion/TritonToTritonCPU/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls --name TritonToTritonCPU) -add_public_tablegen_target(TritonConversionToCPUPassIncGen) diff --git a/include/triton/Conversion/TritonToTritonCPU/Passes.h b/include/triton/Conversion/TritonToTritonCPU/Passes.h deleted file mode 100644 index 4ec0411da1ab..000000000000 --- a/include/triton/Conversion/TritonToTritonCPU/Passes.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef TRITON_CONVERSION_TO_CPU_PASSES_H -#define TRITON_CONVERSION_TO_CPU_PASSES_H - -#include "triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h" - -namespace mlir { -namespace triton { - -#define GEN_PASS_REGISTRATION -#include "triton/Conversion/TritonToTritonCPU/Passes.h.inc" - -} // namespace triton -} // namespace mlir - -#endif diff --git a/include/triton/Conversion/TritonToTritonCPU/Passes.td b/include/triton/Conversion/TritonToTritonCPU/Passes.td deleted file mode 100644 index a15bd15bfcd1..000000000000 --- a/include/triton/Conversion/TritonToTritonCPU/Passes.td +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef TRITON_CONVERSION_TO_CPU_PASSES -#define TRITON_CONVERSION_TO_CPU_PASSES - -include "mlir/Pass/PassBase.td" - -def ConvertTritonToTritonCPU: Pass<"convert-triton-to-tritoncpu", "mlir::ModuleOp"> { - let summary = "Convert Triton to TritonCPU"; - let description = [{ - - }]; - let constructor = "mlir::triton::createConvertTritonToTritonCPUPass()"; - - let dependentDialects = ["mlir::arith::ArithDialect", - "mlir::math::MathDialect", - "mlir::scf::SCFDialect", - "mlir::triton::cpu::TritonCPUDialect", - "mlir::triton::TritonDialect"]; - - let options = [ - ]; -} - -#endif diff --git a/include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h b/include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h deleted file mode 100644 index 2e7acbd24548..000000000000 --- a/include/triton/Conversion/TritonToTritonCPU/TritonToTritonCPUPass.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef TRITON_CONVERSION_TRITONTOTRITONCPU_TRITONTOTRITONCPUPASS_H -#define TRITON_CONVERSION_TRITONTOTRITONCPU_TRITONTOTRITONCPUPASS_H - -#include - -namespace mlir { - -class ModuleOp; -template class OperationPass; - -namespace triton { - -std::unique_ptr> createConvertTritonToTritonCPUPass(); - -} // namespace triton -} // namespace mlir - -#endif diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 1ce9419e8f60..3959bf28f4e1 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -11,8 +11,6 @@ #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "triton/Conversion/TritonCPUToLLVM/Passes.h" -#include "triton/Dialect/TritonCPU/IR/Dialect.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" From 75142b04993356a865418aab0ec762994f278fb2 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Tue, 23 Jul 2024 00:13:01 +0200 Subject: [PATCH 051/165] [Scf If types] Support conversion of types for scf::if (#45) This commit copies same approach as it were for `scf::for`. 
It's aimed to convert types like: `tensor<32xi32>` to `vector<32xi32>`. Maybe there should be some utility that will be used in all such ConversionPasses to avoid code-duplication. Signed-off-by: Dmitrii Makarenko --- python/test/unit/language/test_core.py | 1 + .../ConvertControlFlowOps.cpp | 57 +++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 8ec022110470..88a6d7f1cf80 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6805,6 +6805,7 @@ def kernel(X): raise +@pytest.mark.cpu @pytest.mark.interpreter def test_temp_var_in_loop(device): diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp index 9cf6e31810d7..a0115a897734 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp @@ -70,6 +70,52 @@ struct ForOpConversion : public OpConversionPattern { } }; +// This is borrowed from ConvertFIfOpTypes in +// SCF/Transforms/StructuralTypeConversions.cpp +// and +// lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +class SCFIfPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: Generalize this to any type conversion, not just 1:1. + // + // We need to implement something more sophisticated here that tracks which + // types convert to which other types and does the appropriate + // materialization logic. + // For example, it's possible that one result type converts to 0 types and + // another to 2 types, so newResultTypes would at least be the right size to + // not crash in the llvm::zip call below, but then we would set the the + // wrong type on the SSA values! These edge cases are also why we cannot + // safely use the TypeConverter::convertTypes helper here. + SmallVector newResultTypes; + for (auto type : op.getResultTypes()) { + Type newType = typeConverter->convertType(type); + if (!newType) + return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion"); + newResultTypes.push_back(newType); + } + + // See comments in the ForOp pattern for why we clone without regions and + // then inline. + scf::IfOp newOp = + cast(rewriter.cloneWithoutRegions(*op.getOperation())); + rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + // Update the operands and types. 
+ newOp->setOperands(adaptor.getOperands()); + for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) + std::get<0>(t).setType(std::get<1>(t)); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + struct ConvertControlFlowOps : public triton::impl::ConvertControlFlowOpsBase { using ConvertControlFlowOpsBase::ConvertControlFlowOpsBase; @@ -93,6 +139,17 @@ struct ConvertControlFlowOps return signalPassFailure(); } + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); + { + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } + convTarget.addDynamicallyLegalOp( [&](Operation *op) -> std::optional { return typeConverter.isLegal(op); From 1f0ca4a3f04c0262a3f733ef9b22ec14beb3a9a9 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Tue, 23 Jul 2024 03:35:24 +0200 Subject: [PATCH 052/165] [WA for fp16 torch.matmul] Replace torch.matmul with np.matmul (#44) This commit replaces torch.matmul for fp16 cpu case with np.matmul. As there is no implementation for such configuration. Signed-off-by: Dmitrii Makarenko --- python/test/unit/language/test_core.py | 43 ++++++++++++++++---------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 88a6d7f1cf80..ed78db94a658 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4060,15 +4060,6 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") - elif is_cpu(): - input_precision = "ieee" - # TODO(dmitriim): investigate the reason why - # can be fixed with lower tolerance: - # E Mismatched elements: 94 / 32768 (0.287%) - # E Max absolute difference: 0.09375 - # E Max relative difference: 4.812 - if out_dtype_str == "float16" and in_dtype_str == "float16": - pytest.skip(f"{out_dtype_str} with M = {M}, N = {N}, K = {K} has low precision. 
Not clear why.") else: input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16): @@ -4178,6 +4169,7 @@ def kernel( np.testing.assert_allclose(out_ref, to_numpy(out_tri), rtol=0.01, atol=1e-2) +@pytest.mark.cpu @pytest.mark.parametrize('in_dtype', ['float32']) def test_dot_mulbroadcasted(in_dtype, device): if is_cuda(): @@ -4359,6 +4351,7 @@ def kernel_constexpr(in_ptr: tl.const, out, c_out: tl.const, choose_const: tl.co assert torch.all(input == output) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str", ['float32', 'float16']) def test_dot_without_load(dtype_str, device): @@ -4374,7 +4367,11 @@ def _kernel(out): kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) - out_ref = torch.matmul(a, b) + if is_cpu() and dtype_str == "float16": + # torch.matmul not implemented for Half float (float16) cpu + out_ref = torch.tensor(np.matmul(to_numpy(a), to_numpy(b)), dtype=getattr(torch, dtype_str), device=device) + else: + out_ref = torch.matmul(a, b) out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) kernel[(1, )](out) assert torch.all(out == out_ref) @@ -4478,6 +4475,7 @@ def kernel(in_ptr, out_ptr, size: tl.constexpr, mask: tl.constexpr, other: tl.co # Testing masked loads with a copy to shared memory. # FIXME: Shape too small for ldmatrix when num_ctas=4 +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) def test_masked_load_shared_memory(dtype, device): @@ -4490,11 +4488,11 @@ def test_masked_load_shared_memory(dtype, device): in1 = torch.rand((M, K), dtype=dtype, device=device) in2 = torch.rand((K, N), dtype=dtype, device=device) - out = torch.zeros((M, N), dtype=dtype, device=device) + out = torch.zeros((M, N), dtype=torch.float32, device=device) @triton.jit - def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, M: tl.constexpr, N: tl.constexpr, + K: tl.constexpr): M_offsets = tl.arange(0, M) N_offsets = tl.arange(0, N) @@ -4514,10 +4512,16 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_ output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) - pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), - out.numel(), M=M, N=N, K=K) + _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], M=M, N=N, K=K) + if is_cpu() and (dtype == torch.float16 or dtype == torch.bfloat16): + # torch.matmul not implemented for Half float (float16) cpu + reference_out = torch.tensor(np.matmul(to_numpy(in1), to_numpy(in2))).to(torch.float32) + # f32_in1 = convert_float_to_float32(in1) + # f32_in2 = convert_float_to_float32(in2) + # reference_out = torch.matmul(f32_in1, f32_in2) + else: + reference_out = torch.matmul(in1, in2).to(torch.float32) - reference_out = torch.matmul(in1, in2) torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) @@ -6714,6 +6718,7 @@ def loop_kernel(Z, N: tl.constexpr, step: tl.constexpr): assert 
(Out == Acc).all(), (Out, Acc) +@pytest.mark.cpu @pytest.mark.interpreter def test_tl_range_num_stages(device): if is_hip(): @@ -6727,7 +6732,11 @@ def test_tl_range_num_stages(device): 1, ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, 0, num_stages=5) - ref_out = torch.matmul(a, b).to(torch.float32) + if is_cpu(): + # torch.matmul not implemented for Half float (float16) cpu + ref_out = torch.tensor(np.matmul(to_numpy(a), to_numpy(b))).to(torch.float32) + else: + ref_out = torch.matmul(a, b).to(torch.float32) if is_interpreter(): # GPU invokes tensor core for float16 matmul, which is not supported in interpreter. # Thus we use a higher tolerance From dd5e3e268eef6180c2878533268852cd29c50277 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 23 Jul 2024 11:05:44 -0400 Subject: [PATCH 053/165] [cpu] Have MulhiUI lowering support scalars (#64) This makes most of the int32 tests in test_random.py pass. (int64 ones are still failing.) Fixes https://github.com/triton-lang/triton-cpu/issues/62. --- .../ConvertElementwiseOps.cpp | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 5a66813b0a16..0a209b5a8b25 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -85,28 +86,35 @@ struct MulhiUIOpConversion : public OpConversionPattern { LogicalResult matchAndRewrite(triton::MulhiUIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(isa(op.getType())); auto loc = op.getLoc(); auto lhs = rewriter.getRemappedValue(op.getX()); auto rhs = rewriter.getRemappedValue(op.getY()); - auto lhsTy = dyn_cast(lhs.getType()); - auto rhsTy = dyn_cast(rhs.getType()); - auto vecI32Ty = lhsTy.cloneWith(std::nullopt, rewriter.getI32Type()); - auto vecI64Ty = lhsTy.cloneWith(std::nullopt, rewriter.getI64Type()); - assert(lhsTy.getElementType().isInteger()); - assert(rhsTy.getElementType().isInteger()); - // Cast to int64 - if (lhsTy.getElementTypeBitWidth() < 64) { - lhs = rewriter.create(loc, vecI64Ty, lhs); + + Type extUITy = rewriter.getI64Type(); + Type truncITy = rewriter.getI32Type(); + Value cst32; + if (auto lhsTy = dyn_cast(lhs.getType())) { + assert(isa(rhs.getType())); + extUITy = lhsTy.cloneWith(std::nullopt, extUITy); + truncITy = lhsTy.cloneWith(std::nullopt, truncITy); + cst32 = rewriter.create( + loc, DenseElementsAttr::get(dyn_cast(extUITy), 32LL)); + } else { + cst32 = rewriter.create( + loc, rewriter.getI64IntegerAttr(32LL)); + } + + auto lhsTy = getElementTypeOrSelf(lhs.getType()); + auto rhsTy = getElementTypeOrSelf(rhs.getType()); + if (lhsTy.getIntOrFloatBitWidth() < 64) { + lhs = rewriter.create(loc, extUITy, lhs); } - if (rhsTy.getElementTypeBitWidth() < 64) { - rhs = rewriter.create(loc, vecI64Ty, rhs); + if (rhsTy.getIntOrFloatBitWidth() < 64) { + rhs = rewriter.create(loc, extUITy, rhs); } Value res = rewriter.create(loc, lhs, rhs); - Value cst32 = rewriter.create( - loc, DenseElementsAttr::get(vecI64Ty, 32LL)); res = rewriter.create(loc, res, cst32); - res 
= rewriter.create(loc, vecI32Ty, res); + res = rewriter.create(loc, truncITy, res); rewriter.replaceOp(op, res); return success(); } From c4ea76157668f490de876035cf4616776fa7a8bc Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 23 Jul 2024 12:57:27 -0400 Subject: [PATCH 054/165] [cpu] Fix formatting (#65) Our pre-commit hook is failing --- python/test/unit/language/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index ed78db94a658..1083647132d5 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -319,8 +319,10 @@ def kernel(X, SIZE: tl.constexpr): x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) kernel[(1, )](x, SIZE=SIZE, num_warps=4) + @pytest.mark.cpu def test_empty_kernel_scalar_arg(device): + @triton.jit def kernel(x): pass From fe9f0cd8a28d4a332a959dae4a206043d498dd39 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 23 Jul 2024 16:44:51 -0400 Subject: [PATCH 055/165] [cpu] Support tl.load(..., padding="nan") (#69) This makes test_block_pointer.py pass. Fixes #60. --- .github/workflows/build-test.yml | 4 ++- .../TritonToTritonCPU/ConvertMemoryOps.cpp | 30 +++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index a2cdc22e1920..88ec63c6c2d0 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -79,4 +79,6 @@ jobs: - name: Run python unit tests run: | python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu - python -m pytest -s -n 32 --device cpu python/test/unit/language/test_conversions.py + python -m pytest -s -n 32 --device cpu \ + python/test/unit/language/test_block_pointer.py \ + python/test/unit/language/test_conversions.py diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 2787247a731c..2bd3f2c87f6f 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -96,6 +96,30 @@ struct MemoryOpConversion : public OpConversionPattern { struct LoadOpConversion : public MemoryOpConversion { using MemoryOpConversion::MemoryOpConversion; + static Value + getPaddingValue(Location loc, Type type, + const std::optional &padding, + ConversionPatternRewriter &rewriter) { + if (!padding.has_value()) + return Value(); + + TypedAttr attr; + switch (padding.value()) { + case triton::PaddingOption::PAD_ZERO: + attr = type.isIntOrIndex() ? 
cast(IntegerAttr::get(type, 0)) + : cast(FloatAttr::get(type, 0)); + break; + case triton::PaddingOption::PAD_NAN: + assert(!type.isIntOrIndex()); + auto apNaN = llvm::APFloat::getNaN( + FloatAttr::get(type, 0).getValue().getSemantics()); + attr = FloatAttr::get(type, apNaN); + break; + } + + return rewriter.create(loc, attr); + } + LogicalResult matchAndRewrite(triton::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -126,8 +150,10 @@ struct LoadOpConversion : public MemoryOpConversion { for (auto dim : boundaryChecks) { inBounds[dim] = false; } - auto vecRead = rewriter.create(loc, resTy, memRef, - indices, inBounds); + Value padding = getPaddingValue(loc, resTy.getElementType(), + loadOp.getPadding(), rewriter); + auto vecRead = rewriter.create( + loc, resTy, memRef, indices, padding, inBounds); rewriter.replaceOp(loadOp, vecRead); return success(); } From 60690e93ff63b5ebe438faa7854227ff5f6670e2 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 23 Jul 2024 16:45:38 -0400 Subject: [PATCH 056/165] [cpu] Use helpers from OptCommon.h to simplify code (#67) Per https://github.com/triton-lang/triton-cpu/pull/64#issuecomment-2245758319 --- .../TritonCPUTransforms/OptCommon.h | 4 ++++ .../ConvertUnsupportedOps.cpp | 3 +-- .../DecomposeFpConversions.cpp | 3 +-- .../ConvertElementwiseOps.cpp | 18 ++++-------------- 4 files changed, 10 insertions(+), 18 deletions(-) rename third_party/cpu/{lib => include}/TritonCPUTransforms/OptCommon.h (98%) diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h similarity index 98% rename from third_party/cpu/lib/TritonCPUTransforms/OptCommon.h rename to third_party/cpu/include/TritonCPUTransforms/OptCommon.h index ebf4ae723248..0fe6dc64c5b2 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptCommon.h +++ b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h @@ -56,6 +56,10 @@ inline Type toInt32(Type ty) { return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 32)); } +inline Type toInt64(Type ty) { + return toTyOrVectorOf(ty, IntegerType::get(ty.getContext(), 64)); +} + inline Type toFp8E5M2(Type ty) { return toTyOrVectorOf(ty, Float8E5M2Type::get(ty.getContext())); } diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index 5d991b376902..ccd36a4c03ba 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -1,5 +1,4 @@ -#include "OptCommon.h" - +#include "cpu/include/TritonCPUTransforms/OptCommon.h" #include "cpu/include/TritonCPUTransforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp index 02d7087bf0a2..df1d0e34cd87 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -1,5 +1,4 @@ -#include "OptCommon.h" - +#include "cpu/include/TritonCPUTransforms/OptCommon.h" #include "cpu/include/TritonCPUTransforms/Passes.h" #include "mlir/Dialect/Utils/IndexingUtils.h" diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 0a209b5a8b25..ff7bd29e0818 100644 --- 
a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -1,6 +1,7 @@ #include "OpTypeConversion.h" #include "TypeConverter.h" +#include "cpu/include/TritonCPUTransforms/OptCommon.h" #include "cpu/include/TritonToTritonCPU/Passes.h" #include "mlir/Analysis/DataFlowFramework.h" @@ -90,20 +91,9 @@ struct MulhiUIOpConversion : public OpConversionPattern { auto lhs = rewriter.getRemappedValue(op.getX()); auto rhs = rewriter.getRemappedValue(op.getY()); - Type extUITy = rewriter.getI64Type(); - Type truncITy = rewriter.getI32Type(); - Value cst32; - if (auto lhsTy = dyn_cast(lhs.getType())) { - assert(isa(rhs.getType())); - extUITy = lhsTy.cloneWith(std::nullopt, extUITy); - truncITy = lhsTy.cloneWith(std::nullopt, truncITy); - cst32 = rewriter.create( - loc, DenseElementsAttr::get(dyn_cast(extUITy), 32LL)); - } else { - cst32 = rewriter.create( - loc, rewriter.getI64IntegerAttr(32LL)); - } - + Type extUITy = toInt64(lhs.getType()); + Type truncITy = toInt32(lhs.getType()); + Value cst32 = intCst(loc, extUITy, 32LL, rewriter); auto lhsTy = getElementTypeOrSelf(lhs.getType()); auto rhsTy = getElementTypeOrSelf(rhs.getType()); if (lhsTy.getIntOrFloatBitWidth() < 64) { From bd0ea60fe959419e0c91d0f75a733a1a22fd554f Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 23 Jul 2024 19:24:58 -0400 Subject: [PATCH 057/165] [cpu] Follow up to #69 (#70) --- .../lib/TritonToTritonCPU/ConvertMemoryOps.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 2bd3f2c87f6f..8896ab5ebd8d 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -100,19 +100,17 @@ struct LoadOpConversion : public MemoryOpConversion { getPaddingValue(Location loc, Type type, const std::optional &padding, ConversionPatternRewriter &rewriter) { - if (!padding.has_value()) - return Value(); + auto padding_option = padding.value_or(PaddingOption::PAD_ZERO); TypedAttr attr; - switch (padding.value()) { - case triton::PaddingOption::PAD_ZERO: - attr = type.isIntOrIndex() ? cast(IntegerAttr::get(type, 0)) - : cast(FloatAttr::get(type, 0)); + switch (padding_option) { + case PaddingOption::PAD_ZERO: + attr = rewriter.getZeroAttr(type); break; - case triton::PaddingOption::PAD_NAN: + case PaddingOption::PAD_NAN: assert(!type.isIntOrIndex()); - auto apNaN = llvm::APFloat::getNaN( - FloatAttr::get(type, 0).getValue().getSemantics()); + auto apNaN = + llvm::APFloat::getNaN(cast(type).getFloatSemantics()); attr = FloatAttr::get(type, apNaN); break; } From 9508a2f03041a91fd4661fd079107bd967155bec Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 25 Jul 2024 13:19:12 -0400 Subject: [PATCH 058/165] [cpu] Add runtime library for CPU kernels (#73) This will make it simpler to implement device_assert, device_print etc. 
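The kind of kernel-side call this library is meant to back is sketched below; this is a hypothetical example, not part of the patch, and the lowering of `tl.device_assert` to the new `triton_assert` runtime helper is not implemented here. The patch only adds the shared library (linked via the `TritonCPURuntime` entry in the compiler's library list) and a first helper.

```python
import triton
import triton.language as tl


@triton.jit
def checked_copy(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # A future lowering could route this to triton_assert() in libTritonCPURuntime.
    tl.device_assert(x >= 0, "negative input")
    tl.store(y_ptr + offs, x, mask=mask)
```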
--- python/triton/runtime/build.py | 2 ++ third_party/cpu/CMakeLists.txt | 1 + third_party/cpu/backend/compiler.py | 6 +++--- third_party/cpu/backend/driver.py | 15 ++++++++++----- third_party/cpu/runtime/cpu_runtime.cpp | 6 ++++++ 5 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 third_party/cpu/runtime/cpu_runtime.cpp diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 44a8baefe65e..a489c7df3317 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -52,6 +52,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] + for dir in library_dirs: + cc_cmd.extend(["-rpath", dir]) # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. if src.endswith(".cpp") or src.endswith(".cc"): cc_cmd += ["-std=c++17", "-fopenmp"] diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index b107e2434e1e..c0bbcbfca1f6 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -5,4 +5,5 @@ add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm) + add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) endif() diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6e8a3b6c329c..024e8f252ef6 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -182,9 +182,9 @@ def make_so(src, metadata, options): "kernel", asm_path, tmpdir, - cpu_driver.library_dir, - cpu_driver.include_dir, - ["gcc", "m"], + cpu_driver.library_dirs, + cpu_driver.include_dirs, + ["gcc", "m", "TritonCPURuntime"], ) with open(so, "rb") as f: return f.read() diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index a5eef2ed9fb8..fbeb129e0c15 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,15 +1,20 @@ import os import hashlib +import importlib import tempfile -from pathlib import Path + +import triton._C from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget -dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") -include_dir = [os.path.join(dirname, "include")] -library_dir = [os.path.join(dirname, "lib")] +_dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") +# for locating libTritonCPURuntime +_triton_C_dir = importlib.resources.files(triton._C).joinpath("") + +include_dirs = [os.path.join(_dirname, "include")] +library_dirs = [os.path.join(_dirname, "lib"), _triton_C_dir] libraries = ["stdc++"] @@ -22,7 +27,7 @@ def compile_module_from_src(src, name): src_path = os.path.join(tmpdir, "main.cpp") with open(src_path, "w") as f: f.write(src) - so = _build(name, src_path, tmpdir, library_dir, include_dir, libraries) + so = _build(name, src_path, tmpdir, library_dirs, include_dirs, libraries) with open(so, "rb") as f: cache_path = cache.put(f.read(), f"{name}.so", binary=True) import importlib.util diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp 
new file mode 100644 index 000000000000..0d69ca6c8ab7 --- /dev/null +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -0,0 +1,6 @@ +#include + +void triton_assert(bool cond, char *c) { + if (!cond) + fprintf(stderr, "%s\n", c); +} From cf21e44cd023d437a2cefb3fc4913527a209be65 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Thu, 25 Jul 2024 23:19:27 +0200 Subject: [PATCH 059/165] [FP8 tests] Enable several fp8 tests (#49) This commit enables several fp8 tests, that uses `tl.range`. Signed-off-by: Dmitrii Makarenko --- python/test/unit/language/test_core.py | 38 ++++++++++++++----- .../ConvertUnsupportedOps.cpp | 23 +++++++++++ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 1083647132d5..f465e2b1e674 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -126,6 +126,8 @@ def check_type_supported(dtype, device): if is_interpreter(): if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]: pytest.skip("bfloat16 is not supported in the interpreter") + if dtype == 'float8e4b15' and is_cpu(): + pytest.skip("float8e4b15 not supported on CPU") class MfmaLayout: @@ -1220,6 +1222,7 @@ def test_abs(dtype_x, device): _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("in_dtype", [tl.float8e4b15, tl.float8e4nv, tl.float8e5]) def test_abs_fp8(in_dtype, device): @@ -5714,6 +5717,7 @@ def kernel(Out): # ----------------------- +@pytest.mark.cpu def test_num_threads(device): if is_hip(): pytest.skip("test_num_threads is not supported in HIP") @@ -6430,6 +6434,9 @@ def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.co def f8_to_f16(x, dtype): + if is_cpu(): + assert (False and "Works as expected only for GPU") + @triton.jit def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) @@ -6456,10 +6463,9 @@ def matmul_kernel( # low_precision_acc: tl.constexpr, # num_stages: tl.constexpr = 3 # ): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - pid_m = pid % num_pid_m - pid_n = pid // num_pid_m + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -6478,6 +6484,7 @@ def matmul_kernel( # tl.store(c_ptrs, accumulator) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("M, N, K", [(128, 256, 256)]) @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)]) @@ -6501,16 +6508,23 @@ def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_s num_warps = 8 a = to_triton(A, device=device, dst_type=in_type_str) b = to_triton(B, device=device, dst_type=in_type_str) - grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1) + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) max_num_impressive_acc = low_precision_acc if low_precision_acc <= BLOCK_K else None h = matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, max_num_impressive_acc, num_warps=num_warps, num_stages=num_stages) torch_a = torch.from_numpy(A).to(device=device) - th_a = f8_to_f16(torch_a, in_type_str) torch_b = torch.from_numpy(B).to(device=device) - th_b = f8_to_f16(torch_b, in_type_str) - ref_out = torch.matmul(th_a, 
th_b).to(torch.float32) + if is_cpu() and 'float8' in in_type_str: + in_dtype = getattr(tl, in_type_str) + th_a = convert_float_to_float32(torch_a, in_dtype) + th_b = convert_float_to_float32(torch_b, in_dtype) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + else: + th_a = f8_to_f16(torch_a, in_type_str) + th_b = f8_to_f16(torch_b, in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) else: @@ -6726,7 +6740,11 @@ def test_tl_range_num_stages(device): if is_hip(): pytest.skip("test_tl_range is not supported in HIP") M, N, K = 64, 64, 512 - BLOCK_M, BLOCK_N, BLOCK_K = M, N, 64 + if is_cpu(): + block_m, block_n, block_k = 32, 32, 64 + else: + block_m, block_n, block_k = M, N, 64 + BLOCK_M, BLOCK_N, BLOCK_K = block_m, block_n, block_k a = torch.randn((M, K), device=device, dtype=torch.float16) b = torch.randn((K, N), device=device, dtype=torch.float16) c = torch.empty((M, N), dtype=torch.float32, device=device) @@ -6736,7 +6754,7 @@ def test_tl_range_num_stages(device): BLOCK_K, 0, num_stages=5) if is_cpu(): # torch.matmul not implemented for Half float (float16) cpu - ref_out = torch.tensor(np.matmul(to_numpy(a), to_numpy(b))).to(torch.float32) + ref_out = torch.tensor(np.matmul(to_numpy(a), to_numpy(b)), dtype=torch.float32, device=device) else: ref_out = torch.matmul(a, b).to(torch.float32) if is_interpreter(): diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index ccd36a4c03ba..601edbdb5782 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -166,6 +166,28 @@ struct ConvertBf16Abs : public OpRewritePattern { } }; +struct ConvertF8Abs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::AbsFOp op, + PatternRewriter &rewriter) const override { + if (!isFp8(op.getType()) || !isFp8(op.getOperand().getType())) + return failure(); + + Location loc = op.getLoc(); + Value src = op.getOperand(); + Type srcType = op.getType(); + + Value i8Src = op_bitcast(toInt8(srcType), src); + // Mask out the sign bit + Value nosign = op_and(i8Src, cst_like(i8Src, 0x7f)); + Value res = op_bitcast(srcType, nosign); + + rewriter.replaceOp(op, res); + return success(); + } +}; + struct ConvertMixedPrecisionMatmul : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -253,6 +275,7 @@ struct ConvertUnsupportedOps patterns.add(context); patterns.add(context); } + patterns.add(context); if (convertMixedPrecisionMatmul) { patterns.add(context); } From 34cd5d4049990aa4b44ffb72fed9c53530f3ee12 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 25 Jul 2024 17:38:47 -0400 Subject: [PATCH 060/165] [cpu] Make runtime library build on Linux too (#75) macOS gcc supports `-rpath` directly in the compilation driver, but on Linux we need to pass it directly to the linker. Fortunately this works on macOS too. 
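As a quick illustration, the resulting compile command construction looks roughly like this (a simplified sketch of `runtime/build.py`, mirroring the one-line change below; paths are placeholders):

```python
cc = "gcc"
src, out = "main.cpp", "kernel.so"
library_dirs = ["/path/to/triton/_C"]  # placeholder for the libTritonCPURuntime location
libraries = ["gcc", "m", "TritonCPURuntime"]

cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", out]
cc_cmd += [f"-l{lib}" for lib in libraries]
cc_cmd += [f"-L{d}" for d in library_dirs]
for d in library_dirs:
    # -Wl, hands -rpath to the linker, a spelling both GNU ld and macOS ld64 accept.
    cc_cmd.extend(["-Wl,-rpath", d])
```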
--- python/triton/runtime/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index a489c7df3317..72a66ddec32f 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -53,7 +53,7 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] for dir in library_dirs: - cc_cmd.extend(["-rpath", dir]) + cc_cmd.extend(["-Wl,-rpath", dir]) # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. if src.endswith(".cpp") or src.endswith(".cc"): cc_cmd += ["-std=c++17", "-fopenmp"] From ffc885ae1cc0f6970fdb750e0a3de3c18bf4c8c2 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 26 Jul 2024 09:32:55 -0400 Subject: [PATCH 061/165] [cpu] Get more of test_random.py working (#77) Two changes: 1. Lower mulhiui to mulhi_extended rather than a bunch of ext / trunc instructions. This greatly simplifies the code, as well as automatically gets it working for 64-bit inputs, which the previous implementation did not handle correctly 2. Only emit extsi if the result bitwidth is larger than the input bitwidth. Otherwise it fails validation. This gets the int64 tests in test_random.py passing. Fixes https://github.com/triton-lang/triton-cpu/issues/71. --- .../ConvertElementwiseOps.cpp | 18 +++--------------- .../lib/TritonToTritonCPU/ConvertPtrOps.cpp | 4 +++- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index ff7bd29e0818..8c2377babbac 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -5,6 +5,7 @@ #include "cpu/include/TritonToTritonCPU/Passes.h" #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -90,21 +91,8 @@ struct MulhiUIOpConversion : public OpConversionPattern { auto loc = op.getLoc(); auto lhs = rewriter.getRemappedValue(op.getX()); auto rhs = rewriter.getRemappedValue(op.getY()); - - Type extUITy = toInt64(lhs.getType()); - Type truncITy = toInt32(lhs.getType()); - Value cst32 = intCst(loc, extUITy, 32LL, rewriter); - auto lhsTy = getElementTypeOrSelf(lhs.getType()); - auto rhsTy = getElementTypeOrSelf(rhs.getType()); - if (lhsTy.getIntOrFloatBitWidth() < 64) { - lhs = rewriter.create(loc, extUITy, lhs); - } - if (rhsTy.getIntOrFloatBitWidth() < 64) { - rhs = rewriter.create(loc, extUITy, rhs); - } - Value res = rewriter.create(loc, lhs, rhs); - res = rewriter.create(loc, res, cst32); - res = rewriter.create(loc, truncITy, res); + Value res = + rewriter.create(loc, lhs, rhs).getHigh(); rewriter.replaceOp(op, res); return success(); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp index 82123c376dc1..27f49a3078c1 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -118,6 +118,7 @@ struct AddPtrOpConversion : public OpConversionPattern { assert(isa(offset.getType())); assert(isa(ptr.getType())); VectorType offsetTy = cast(offset.getType()); + VectorType ptrTy = 
cast(ptr.getType()); // Build scale vector. i1 elements take 1 byte. Value scale = rewriter.create( loc, offsetTy, @@ -125,7 +126,8 @@ struct AddPtrOpConversion : public OpConversionPattern { offsetTy, rewriter.getIntegerAttr(offsetTy.getElementType(), (elemBitWidth + 7) / 8))); offset = rewriter.create(loc, offset, scale); - offset = rewriter.create(loc, ptr.getType(), offset); + if (offsetTy.getElementTypeBitWidth() < ptrTy.getElementTypeBitWidth()) + offset = rewriter.create(loc, ptr.getType(), offset); rewriter.replaceOpWithNewOp(op, ptr.getType(), ptr, offset); return success(); } From bb57572bce98c712d7f4780f0a039e8a35228aba Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Mon, 29 Jul 2024 19:11:35 +0200 Subject: [PATCH 062/165] [FIX Pytest] Resolve 'importlib' issue (#78) This commit resolves issue, that occurs on main 'pytest run'. ``` _triton_C_dir = importlib.resources.files(triton._C).joinpath("") E AttributeError: module 'importlib' has no attribute 'resources' ``` --- third_party/cpu/backend/driver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index fbeb129e0c15..2cedbcae99ba 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -1,6 +1,7 @@ import os import hashlib import importlib +import importlib.resources import tempfile import triton._C From 679a88be2fbbd6816c29fb11754eb268c0374ec4 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Mon, 29 Jul 2024 15:35:10 -0400 Subject: [PATCH 063/165] Fix importlib issues (#80) The previous implementation worked on Python 3.11 but had a host of issues with other versions. It boiled down to `triton._C` not being a regular package -- it doesn't have an `__init__.py` file. However, it can still be imported as a [namespace package][1]. Namespace packages can map to multiple locations on the filesystem, so we cannot get a path to the package contents without materializing the package. The solution is to look up the files of the top-level `triton` package instead, which is a regular package, and use that to find the location of the `_C` directory. I've tested that this approach works on 3.9 and 3.12 in macOS. Fixes #76. [1]: https://packaging.python.org/en/latest/guides/packaging-namespace-packages/#native-namespace-packages --- third_party/cpu/backend/driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 2cedbcae99ba..a9eb00a874f3 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -12,7 +12,7 @@ _dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") # for locating libTritonCPURuntime -_triton_C_dir = importlib.resources.files(triton._C).joinpath("") +_triton_C_dir = importlib.resources.files(triton).joinpath("_C") include_dirs = [os.path.join(_dirname, "include")] library_dirs = [os.path.join(_dirname, "lib"), _triton_C_dir] From b3e65975599931f52923c1fb079cdbc5e41fbbc9 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 30 Jul 2024 17:17:55 -0400 Subject: [PATCH 064/165] [cpu] Add test_annotations.py to CI (#81) The only test that needed fixing was `test_unknown_annotations`, where we were generating invalid code for the launcher. 
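For context, the launcher builds a C++ function-pointer type by joining per-argument types into a string. A simplified, self-contained sketch of the old construction (names borrowed from `driver.py`, with a hypothetical all-constexpr signature) shows how the malformed string appears:

```python
# Hypothetical signature in which every argument is a constant, so no runtime args remain.
signature = {0: "constexpr"}
constants = {0: 16}
ty_to_cpp = lambda ty: "int32_t"  # stand-in for the real type mapping

kernel_fn_args = [i for i in signature.keys() if i not in constants]  # -> []
# Old logic: a trailing ", " was appended whenever the signature was non-empty,
# even though the join over kernel_fn_args contributed nothing.
kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", "
                       if len(signature) > 0 else '') + ", ".join(["uint32_t"] * 6)
print(f"using kernel_ptr_t = void(*)({kernel_fn_arg_types});")
# -> using kernel_ptr_t = void(*)(, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t);
```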
In particular, when `kernel_fn_args` was empty, we would get the following error: ``` /var/folders/_z/88s630fd3d9fx72mbmx90qvw0000gn/T/tmpy481mz0l/main.cpp:37:29: error: expected ';' before '(' token 37 | using kernel_ptr_t = void(*)(, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t); | ^ | ; ``` --- .github/workflows/build-test.yml | 1 + third_party/cpu/backend/driver.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 88ec63c6c2d0..0de19f13c3cf 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -80,5 +80,6 @@ jobs: run: | python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu python -m pytest -s -n 32 --device cpu \ + python/test/unit/language/test_annotations.py \ python/test/unit/language/test_block_pointer.py \ python/test/unit/language/test_conversions.py diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index a9eb00a874f3..08400174bc09 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -122,11 +122,10 @@ def format_of(ty): args_format = ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) format = "iiiOKOOOO" + args_format - arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) if len(signature) > 0 else '' + arg_ptrs_list = ', '.join(f"&arg{i}" for i, ty in signature.items()) kernel_fn_args = [i for i in signature.keys() if i not in constants] - kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) if len(kernel_fn_args) > 0 else '' - kernel_fn_arg_types = (', '.join(f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args) + ", " - if len(signature) > 0 else '') + "uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, uint32_t" + kernel_fn_args_list = ', '.join(f"arg{i}" for i in kernel_fn_args) + kernel_fn_arg_types = ', '.join([f"{ty_to_cpp(signature[i])}" for i in kernel_fn_args] + ["uint32_t"] * 6) # generate glue code src = f""" From 21ab56bd2846010e57d8e769c6892b14dbe49c3b Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 31 Jul 2024 20:36:40 -0500 Subject: [PATCH 065/165] Reduce/disable some tests on CPU for faster CI runs. (#83) * Reduce/disable some tests on CPU for faster CI runs. Signed-off-by: Ilya Enkovich * Add bf16 conversion for transfer reads/writes. Signed-off-by: Ilya Enkovich * [WA16fp] Fix Test block pointer This commit fixes issue with `test_block_ptr_matmul_no_scf`. Signed-off-by: Dmitrii Makarenko * Reduce problem size for test_block_ptr_matmul_no_scf on CPU. 
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich Signed-off-by: Dmitrii Makarenko Co-authored-by: Dmitrii Makarenko --- .../test/unit/language/test_block_pointer.py | 11 ++- python/test/unit/language/test_core.py | 12 +++- .../ConvertUnsupportedOps.cpp | 70 ++++++++++++++++--- 3 files changed, 80 insertions(+), 13 deletions(-) diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index aff7a29d8781..1f2f5b5e995f 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -3,7 +3,7 @@ import triton import triton.language as tl -from test_core import check_type_supported +from test_core import check_type_supported, is_cpu @triton.jit @@ -101,6 +101,9 @@ def matmul_no_scf_with_advance_kernel( # ]) def test_block_ptr_matmul_no_scf(shape, num_warps, device): m, n, k = shape + if is_cpu(): + # FIXME: fix compilation time for bigger shapes on CPU + m = n = 16 a = torch.randn((m, k), device=device, dtype=torch.float16) b = torch.randn((k, n), device=device, dtype=torch.float16) c = torch.empty((m, n), device=device, dtype=torch.float32) @@ -114,5 +117,9 @@ def test_block_ptr_matmul_no_scf(shape, num_warps, device): stride_cm=c.stride(0), stride_cn=c.stride(1), # BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # num_warps=num_warps) - golden = torch.matmul(a, b) + if is_cpu(): + # torch.matmul not implemented for Half float (float16) cpu + golden = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + else: + golden = torch.matmul(a, b) torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f465e2b1e674..69abbbf0a7c6 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3331,6 +3331,10 @@ def test_permute(dtype_str, shape, perm, num_ctas, device): if shape == (128, 128) and dtype_str == 'float32': pytest.skip("TODO Out of LDS for float32 with shape 128x128") + if is_cpu(): + # FIXME: compilation time for big shapes is too long + shape = tuple(dim // 4 for dim in shape) + # triton kernel @triton.jit def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): @@ -3562,9 +3566,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty # for bigger sizes with the current codegen on CPU. Limit input sizes # by default to get more reasonable tests execution time. 
if os.environ.get('TRITON_CPU_TEST_DOT_FULL_SIZE', '0') != '1': - M = min(M, 64) - N = min(N, 64) - K = min(K, 32) + M = min(M, 32 if epilogue == "chain-dot" else 64) + N = min(N, 32 if epilogue == "chain-dot" else 64) + K = min(K, 16 if epilogue == "chain-dot" else 32) else: if not is_hip() and (M < 16 or N < 16 or K < 16): pytest.skip("small dots are supported only on HIP at the moment") @@ -4065,6 +4069,8 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_ pytest.skip(f"{in_dtype_str} is not supported in WMMA dot, FMA does not support dot3d") if out_dtype_str == "float16": pytest.skip(f"{out_dtype_str} has low precision in WMMA dot") + elif is_cpu(): + pytest.skip("Test is skipped due to too long execution time on CPU") else: input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee" if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16): diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index 601edbdb5782..d568a7bd5bb2 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -87,18 +87,24 @@ struct ConvertIToBf16ToFp32 : public OpRewritePattern { }; Value convertMemRefToI16(Value memRef, PatternRewriter &rewriter) { - // Memory references for masked operations are always built - // with PtrToMemRefOp. - auto def = memRef.getDefiningOp(); - assert(def); - auto insPoint = rewriter.saveInsertionPoint(); - rewriter.setInsertionPointAfter(def); + Value res; MemRefType memRefTy = cast(memRef.getType()); Type newMemRefTy = MemRefType::get(memRefTy.getShape(), rewriter.getI16Type(), memRefTy.getLayout(), memRefTy.getMemorySpace()); - Value res = rewriter.create(memRef.getLoc(), newMemRefTy, - def.getSrc()); + auto insPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfter(memRef.getDefiningOp()); + // Memory references for masked operations and transfers are always built + // with PtrToMemRefOp or ExtractMemRefOp. + if (auto castOp = memRef.getDefiningOp()) { + res = rewriter.create(memRef.getLoc(), newMemRefTy, + castOp.getSrc()); + } else { + auto extractOp = memRef.getDefiningOp(); + assert(extractOp && "Unexpected memref producer"); + res = rewriter.create(memRef.getLoc(), newMemRefTy, + extractOp.getSrc()); + } rewriter.restoreInsertionPoint(insPoint); return res; } @@ -143,6 +149,52 @@ struct ConvertBf16MaskedStoreOp } }; +struct ConvertBf16TransferReadOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value newSource = convertMemRefToI16(op.getSource(), rewriter); + Value newPadding = + op.getPadding() + ? 
rewriter.create( + loc, toInt16(op.getPadding().getType()), op.getPadding()) + : nullptr; + Value intVal = rewriter.create( + loc, toInt16(op.getType()), newSource, op.getIndices(), + op.getPermutationMapAttr(), newPadding, op.getMask(), + op.getInBoundsAttr()); + Value res = rewriter.create(loc, op.getType(), intVal); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertBf16TransferWriteOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getVector().getType())) + return failure(); + + Location loc = op.getLoc(); + Value newSource = convertMemRefToI16(op.getSource(), rewriter); + Value intVal = rewriter.create( + loc, toInt16(op.getVector().getType()), op.getVector()); + rewriter.replaceOpWithNewOp( + op, intVal, newSource, op.getIndices(), op.getPermutationMapAttr(), + op.getMask(), op.getInBoundsAttr()); + return success(); + } +}; + struct ConvertBf16Abs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -273,6 +325,8 @@ struct ConvertUnsupportedOps patterns.add>(context); patterns.add(context); patterns.add(context); + patterns.add(context); + patterns.add(context); patterns.add(context); } patterns.add(context); From af0fc06e8fe737e07296307ec1fbfaa1e9a386e9 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Sun, 4 Aug 2024 21:24:33 -0400 Subject: [PATCH 066/165] [cpu] Don't reuse shuffle dummies (#88) This results in the following compile-time assertion error (in debug Triton builds): Assertion `Index < size() && "invalid index for value range"' failed. This occurs when there is more than one tt.reduce call with a given number of arguments in the same function, with the later call using more arguments. Reusing the dummy values means that the subsequent call has fewer dummy values than expected, hence the error. This bug also resulted in type mismatches errors between the reused dummy value and the current input value. --- .github/workflows/build-test.yml | 9 +++++++++ test/TritonCPU/reduction.mlir | 18 ++++++++++++++++++ .../TritonToTritonCPU/ConvertReductionOp.cpp | 2 +- .../lib/TritonToTritonCPU/ConvertScanOp.cpp | 2 +- .../lib/TritonToTritonCPU/ReduceScanCommon.h | 18 ++++++++---------- 5 files changed, 37 insertions(+), 12 deletions(-) create mode 100644 test/TritonCPU/reduction.mlir diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 0de19f13c3cf..6967e13b8eb3 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -83,3 +83,12 @@ jobs: python/test/unit/language/test_annotations.py \ python/test/unit/language/test_block_pointer.py \ python/test/unit/language/test_conversions.py + + - name: Run lit tests + run: | + cd python + LIT_TEST_DIR="build/$(ls build | grep -i cmake)/test" + if [ ! -d "${LIT_TEST_DIR}" ]; then + echo "Could not find '${LIT_TEST_DIR}'" ; exit -1 + fi + lit -v "${LIT_TEST_DIR}/TritonCPU" diff --git a/test/TritonCPU/reduction.mlir b/test/TritonCPU/reduction.mlir new file mode 100644 index 000000000000..b3c1430e7b41 --- /dev/null +++ b/test/TritonCPU/reduction.mlir @@ -0,0 +1,18 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-reduction -canonicalize + +// Regression test: Check that we handle consecutive calls to tt.reduce with +// different types & number of arguments. 
+ +module { + tt.func public @triton_(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xi32>) { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: f32): + tt.reduce.return %arg3 : f32 + }) : (tensor<1x4xf32>) -> tensor<1xf32> + %1:2 = "tt.reduce"(%arg0, %arg1) <{axis = 1 : i32}> ({ + ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): + tt.reduce.return %arg3, %arg4 : f32, i32 + }) : (tensor<1x4xf32>, tensor<1x4xi32>) -> (tensor<1xf32>, tensor<1xi32>) + tt.return + } +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index d3a76d9a841b..e660edaf97a5 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -66,7 +66,7 @@ struct ReduceOpConversion SmallVector range(vecSize); std::iota(range.begin(), range.end(), 0); - ArrayRef dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector dummies = createShuffleDummies(loc, inputs, rewriter); SmallVector res = inputs; for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) { SmallVector shuffleIndices = range; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp index 5425b5dbf800..fef15b046621 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp @@ -52,7 +52,7 @@ struct ScanOpConversion int64_t vecSize = cast(inputs[0].getType()).getShape()[0]; Type maskTy = VectorType::get(vecSize, rewriter.getI1Type()); - ArrayRef dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector dummies = createShuffleDummies(loc, inputs, rewriter); SmallVector res = inputs; for (int64_t stride = 1; stride < vecSize; stride *= 2) { SmallVector shuffleIndices(vecSize, 0); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h index ba2d64d8f5f0..2a00f087125b 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h +++ b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h @@ -221,24 +221,22 @@ struct ReduceScanOpConversionBase : public OpConversionPattern { // Dummy vectors are required for shuffles that cannot work on a single // vector. - ArrayRef + SmallVector createShuffleDummies(Location loc, ValueRange inputs, ConversionPatternRewriter &rewriter) const { - if (shuffleDummies.empty()) { - SmallVector dummyShape({1}); - for (auto val : inputs) { - auto ty = cast(val.getType()); - shuffleDummies.push_back(rewriter.create( - loc, rewriter.getZeroAttr( - ty.cloneWith(dummyShape, ty.getElementType())))); - } + SmallVector shuffleDummies; + SmallVector dummyShape({1}); + for (auto val : inputs) { + auto ty = cast(val.getType()); + shuffleDummies.push_back(rewriter.create( + loc, + rewriter.getZeroAttr(ty.cloneWith(dummyShape, ty.getElementType())))); } return shuffleDummies; } private: mutable IRMapping invariantsMap; - mutable SmallVector shuffleDummies; }; } // namespace cpu From c8b43fec82a3b0abca393589f0605d5777de3b95 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 5 Aug 2024 18:51:20 -0500 Subject: [PATCH 067/165] Utilize vector math functions from libmvec. (#55) * Lower vector math operations to libmvec calls. Signed-off-by: Ilya Enkovich * Add new tests to CI run. Signed-off-by: Ilya Enkovich * Fix CPU libdevice import. Signed-off-by: Ilya Enkovich * Fix asm parsing in libmvec test. 
Signed-off-by: Ilya Enkovich * Promote libm ops to FP32. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich Co-authored-by: Minjang Kim --- .github/workflows/build-test.yml | 4 +- python/setup.py | 1 + python/test/unit/cpu/test_libmvec.py | 96 +++++++ python/triton/language/extra/cpu/libdevice.py | 40 +-- third_party/cpu/backend/compiler.py | 10 +- third_party/cpu/backend/driver.py | 5 +- .../cpu/include/TritonCPUToLLVM/Passes.h | 1 + .../cpu/include/TritonCPUToLLVM/Passes.td | 13 + .../cpu/include/TritonCPUTransforms/Passes.h | 3 +- .../cpu/include/TritonCPUTransforms/Passes.td | 3 + .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 1 + .../cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp | 258 ++++++++++++++++++ .../ConvertUnsupportedOps.cpp | 63 ++++- third_party/cpu/triton_cpu.cc | 8 +- 14 files changed, 474 insertions(+), 32 deletions(-) create mode 100644 python/test/unit/cpu/test_libmvec.py create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 6967e13b8eb3..853196ebbba7 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -82,7 +82,9 @@ jobs: python -m pytest -s -n 32 --device cpu \ python/test/unit/language/test_annotations.py \ python/test/unit/language/test_block_pointer.py \ - python/test/unit/language/test_conversions.py + python/test/unit/language/test_conversions.py \ + python/test/unit/cpu/test_libdevice.py \ + python/test/unit/cpu/test_libmvec.py - name: Run lit tests run: | diff --git a/python/setup.py b/python/setup.py index ec32ea11a5b8..6fcea0acc354 100644 --- a/python/setup.py +++ b/python/setup.py @@ -689,6 +689,7 @@ def get_packages(): "triton/compiler", "triton/language", "triton/language/extra", + "triton/language/extra/cpu", "triton/runtime", "triton/backends", "triton/tools", diff --git a/python/test/unit/cpu/test_libmvec.py b/python/test/unit/cpu/test_libmvec.py new file mode 100644 index 000000000000..55dc7ec067e4 --- /dev/null +++ b/python/test/unit/cpu/test_libmvec.py @@ -0,0 +1,96 @@ +import os +import pytest +import torch + +import triton +import triton.language as tl +from triton.language.extra import libdevice + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + +def is_x86(): + return is_cpu() and \ + triton.runtime.driver.active.get_current_target().arch == "x86_64" + + +float_dtypes = ['bfloat16', 'float16', 'float32', 'float64'] + + +@pytest.mark.parametrize("dtype_str", float_dtypes) +@pytest.mark.parametrize("math_fn", ["cos", "exp", "exp2", "log", "log2", "sin"]) +@pytest.mark.parametrize("size", [1, 2, 4, 8, 16, 32, 64, 128]) +def test_tensor_math_fn(dtype_str, math_fn, size, device): + if not is_x86(): + pytest.skip("Vectorized libm calls are supported for x86 target only.") + + @triton.jit + def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = getattr(x, MATH_FN)() + tl.store(dst + idxs, y) + + src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) + meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size) + ref = getattr(src, math_fn)() + torch.testing.assert_close(ref, res) + + # Check generated code calls vector math function + # FP16 and BF16 are casted to 
FP32 for math ops + elem_size = 8 if dtype_str == "float64" else 4 + data_size = size * elem_size + num_vec_calls = 0 + if data_size >= 16: + num_vec_calls = 1 + if data_size > 64: + num_vec_calls = data_size / 64 + assert meta.asm["asm"].count("_ZGV") == num_vec_calls + + +@pytest.mark.parametrize("dtype_str", float_dtypes) +@pytest.mark.parametrize("math_fn", [ + "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "log", "log2", + "log10", "sin", "sinh", "tan", "tanh" +]) +@pytest.mark.parametrize("size", [1, 2, 4, 8, 16, 32, 64, 128]) +def test_libdevice_math_fn(dtype_str, math_fn, size, device): + if not is_x86(): + pytest.skip("Vectorized libm calls are supported for x86 target only.") + + @triton.jit + def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = getattr(libdevice, MATH_FN)(x) + tl.store(dst + idxs, y) + + src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + if math_fn == "acosh": + src = src.abs() + 1 + res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) + meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size) + if math_fn == "cbrt": + ref = src.pow(1 / 3) + else: + ref = getattr(src, math_fn)() + torch.testing.assert_close(ref, res) + + # Check generated code calls vector math function + # FP16 and BF16 are casted to FP32 for math ops + elem_size = 8 if dtype_str == "float64" else 4 + data_size = size * elem_size + num_vec_calls = 0 + if data_size >= 16: + num_vec_calls = 1 + if data_size > 64: + num_vec_calls = data_size / 64 + assert meta.asm["asm"].count("_ZGV") == num_vec_calls diff --git a/python/triton/language/extra/cpu/libdevice.py b/python/triton/language/extra/cpu/libdevice.py index d7e8cdde3cfd..bc1926f4b893 100644 --- a/python/triton/language/extra/cpu/libdevice.py +++ b/python/triton/language/extra/cpu/libdevice.py @@ -1,96 +1,96 @@ -from triton.language import core, tensor +from triton.language import core @core.extern def acos(arg0, _builder=None): - return tensor(_builder.create_acos(arg0.handle), arg0.type) + return core.tensor(_builder.create_acos(arg0.handle), arg0.type) @core.extern def acosh(arg0, _builder=None): - return tensor(_builder.create_acosh(arg0.handle), arg0.type) + return core.tensor(_builder.create_acosh(arg0.handle), arg0.type) @core.extern def asin(arg0, _builder=None): - return tensor(_builder.create_asin(arg0.handle), arg0.type) + return core.tensor(_builder.create_asin(arg0.handle), arg0.type) @core.extern def asinh(arg0, _builder=None): - return tensor(_builder.create_asinh(arg0.handle), arg0.type) + return core.tensor(_builder.create_asinh(arg0.handle), arg0.type) @core.extern def atan(arg0, _builder=None): - return tensor(_builder.create_atan(arg0.handle), arg0.type) + return core.tensor(_builder.create_atan(arg0.handle), arg0.type) @core.extern def atanh(arg0, _builder=None): - return tensor(_builder.create_atanh(arg0.handle), arg0.type) + return core.tensor(_builder.create_atanh(arg0.handle), arg0.type) @core.extern def cbrt(arg0, _builder=None): - return tensor(_builder.create_cbrt(arg0.handle), arg0.type) + return core.tensor(_builder.create_cbrt(arg0.handle), arg0.type) @core.extern def cos(arg0, _builder=None): - return tensor(_builder.create_cos(arg0.handle), arg0.type) + return core.tensor(_builder.create_cos(arg0.handle), arg0.type) @core.extern def cosh(arg0, _builder=None): - return tensor(_builder.create_cosh(arg0.handle), arg0.type) + 
return core.tensor(_builder.create_cosh(arg0.handle), arg0.type) @core.extern def erf(arg0, _builder=None): - return tensor(_builder.create_erf(arg0.handle), arg0.type) + return core.tensor(_builder.create_erf(arg0.handle), arg0.type) @core.extern def exp(arg0, _builder=None): - return tensor(_builder.create_exp(arg0.handle), arg0.type) + return core.tensor(_builder.create_exp(arg0.handle), arg0.type) @core.extern def exp2(arg0, _builder=None): - return tensor(_builder.create_exp2(arg0.handle), arg0.type) + return core.tensor(_builder.create_exp2(arg0.handle), arg0.type) @core.extern def log(arg0, _builder=None): - return tensor(_builder.create_log(arg0.handle), arg0.type) + return core.tensor(_builder.create_log(arg0.handle), arg0.type) @core.extern def log2(arg0, _builder=None): - return tensor(_builder.create_log2(arg0.handle), arg0.type) + return core.tensor(_builder.create_log2(arg0.handle), arg0.type) @core.extern def log10(arg0, _builder=None): - return tensor(_builder.create_log10(arg0.handle), arg0.type) + return core.tensor(_builder.create_log10(arg0.handle), arg0.type) @core.extern def sin(arg0, _builder=None): - return tensor(_builder.create_sin(arg0.handle), arg0.type) + return core.tensor(_builder.create_sin(arg0.handle), arg0.type) @core.extern def sinh(arg0, _builder=None): - return tensor(_builder.create_sinh(arg0.handle), arg0.type) + return core.tensor(_builder.create_sinh(arg0.handle), arg0.type) @core.extern def tan(arg0, _builder=None): - return tensor(_builder.create_tan(arg0.handle), arg0.type) + return core.tensor(_builder.create_tan(arg0.handle), arg0.type) @core.extern def tanh(arg0, _builder=None): - return tensor(_builder.create_tanh(arg0.handle), arg0.type) + return core.tensor(_builder.create_tanh(arg0.handle), arg0.type) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 024e8f252ef6..335feee1366a 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -106,7 +106,10 @@ def make_tttcir(self, mod, metadata, opt): promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features # We don't have any lowering for mixed precision matmuls, so always use casts for now convert_mixed_precision_matmul = True - cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul) + # We don't have math lib functions for FP8, FP16, BF16. Promote such operations to FP32. 
+ promote_lib_math_to_fp32 = True + cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul, + promote_lib_math_to_fp32) decompose_bf16_conv = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features decompose_fp8_conv = True cpu.passes.ttcpuir.add_decompose_fp_conversions(pm, decompose_bf16_conv, decompose_fp8_conv) @@ -116,8 +119,7 @@ def make_tttcir(self, mod, metadata, opt): pm.run(mod) return mod - @staticmethod - def make_llir(src, metadata, options): + def make_llir(self, src, metadata, options): # warp-specialization mutates num_warps num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta") if num_warp_groups is not None: @@ -133,6 +135,8 @@ def make_llir(src, metadata, options): passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) + if self.cpu_arch == "x86_64" and "avx512f" in self.cpu_features: + cpu.passes.ttcpuir.add_math_to_libmvec(pm) passes.convert.add_math_to_llvmir(pm) cpu.passes.ttcpuir.add_math_to_libm(pm) cpu.passes.ttcpuir.add_vector_to_llvmir(pm, options.enable_fast_math) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 08400174bc09..f17a0cba9f30 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -10,6 +10,8 @@ from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget +from triton._C.libtriton import llvm + _dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") # for locating libTritonCPURuntime _triton_C_dir = importlib.resources.files(triton).joinpath("_C") @@ -359,7 +361,8 @@ def get_current_stream(self, device): def get_current_target(self): # Capability and warp size are zeros for CPU. # TODO: GPUTarget naming isn't obviously good. 
- return GPUTarget("cpu", 0, 0) + cpu_arch = llvm.get_cpu_tripple().split("-")[0] + return GPUTarget("cpu", cpu_arch, 0) @staticmethod def is_active(): diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index 288eed5256b4..f00fc0560c3f 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -26,6 +26,7 @@ std::unique_ptr> createGetProgramIdOpToLLVMPass(); std::unique_ptr> createLowerMultiReductionPass(); std::unique_ptr> createAtomicOpsToLLVMPass(); std::unique_ptr> createDebugOpsToLLVMPass(); +std::unique_ptr> createMathToLibmvecPass(); void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); void registerTritonCPUToLLVMPipeline(); diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index 06a9114d7696..10942f807248 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -66,4 +66,17 @@ def DebugOpsToLLVM : Pass<"triton-cpu-debug-ops-to-llvm", "mlir::ModuleOp"> { "mlir::triton::TritonDialect"]; } +def MathToLibmvec : Pass<"triton-cpu-math-to-libmvec", "mlir::ModuleOp"> { + let summary = "Convert vector math operations to vector libm calls."; + let description = [{ + }]; + let constructor = "mlir::triton::cpu::createMathToLibmvecPass()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect", + "mlir::func::FuncDialect", + "mlir::LLVM::LLVMDialect"]; +} + #endif diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index 89810b7ce526..0fab5b2e17ca 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -21,7 +21,8 @@ namespace cpu { std::unique_ptr> createConvertUnsupportedOps(); std::unique_ptr> createConvertUnsupportedOps(bool promoteBf16ToFp32, - bool convertMixedPrecisionMatmul); + bool convertMixedPrecisionMatmul, + bool promoteLibMathToFp32); std::unique_ptr> createDecomposeFpConversions(); std::unique_ptr> createDecomposeFpConversions(bool decomposeBf16Conversions, diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index d2273873310d..a40c47ab0287 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -18,6 +18,9 @@ def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "ml Option<"convertMixedPrecisionMatmul", "convert-mixed-precision-matmul", "bool", /*default*/"false", "Convert inputs of a mixed-precision matmul to a destination type.">, + Option<"promoteLibMathToFp32", "promote-lib-math-to-fp32", + "bool", /*default*/"true", + "Promote FP8, FP16, BF16 math operations mapped to libm function to FP32.">, ]; let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()"; diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index d469b7968682..f355b6d46d4e 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonCPUToLLVM FuncOpToLLVM.cpp GetProgramIdOpToLLVM.cpp LowerMultiReduction.cpp + MathToLibmvec.cpp MemoryOpToLLVM.cpp Pipeline.cpp TypeConverter.cpp diff --git 
a/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp new file mode 100644 index 000000000000..5a035e217fcf --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp @@ -0,0 +1,258 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_MATHTOLIBMVEC +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template struct VecOpToFp32 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + VecOpToFp32(MLIRContext *context) : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy) + return failure(); + + Type elemTy = vecTy.getElementType(); + if (!elemTy.isBF16() && !elemTy.isF16()) + return failure(); + + Type fp32VecTy = vecTy.cloneWith(std::nullopt, rewriter.getF32Type()); + SmallVector fp32Ops; + for (auto operand : op->getOperands()) + fp32Ops.push_back( + rewriter.create(loc, fp32VecTy, operand)); + auto newOp = rewriter.create(loc, fp32VecTy, fp32Ops); + rewriter.replaceOpWithNewOp(op, vecTy, newOp); + return success(); + } +}; + +// Decompose vector operation to singe-dimensional vector operations +// with a native AVX512 vector size. +template +struct DecomposeToNativeVecs : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + DecomposeToNativeVecs(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy) + return failure(); + + Type elemTy = vecTy.getElementType(); + if (!elemTy.isF32() && !elemTy.isF64()) + return failure(); + + int64_t numElems = vecTy.getNumElements(); + if (numElems * elemTy.getIntOrFloatBitWidth() < 128) + return failure(); + + // Produce a new shape where trailing dimensions wouldn't exceed the native + // vector size. + auto shape = vecTy.getShape(); + SmallVector newShape(1, 1); + int64_t elemsPerVec = 512 / elemTy.getIntOrFloatBitWidth(); + for (int64_t i = shape.size() - 1; i >= 0; --i) { + int64_t size = shape[i]; + if (newShape.size() > 1) { + newShape.insert(newShape.begin(), size); + } else { + int64_t combined = newShape[0] * size; + if (combined > elemsPerVec) { + newShape[0] = elemsPerVec; + newShape.insert(newShape.begin(), combined / elemsPerVec); + } else { + newShape[0] = combined; + } + } + } + if (newShape == shape) + return failure(); + + // Convert input operand to the new shape. + SmallVector reshapedInputs; + for (auto operand : op->getOperands()) { + auto operandTy = cast(operand.getType()); + auto newOperandTy = VectorType::get(newShape, operandTy.getElementType()); + reshapedInputs.push_back( + rewriter.create(loc, newOperandTy, operand)); + } + + // Decompose the original operation to a set of operations on native + // vectors. 
+ auto newOpTy = VectorType::get(newShape, elemTy); + auto subResTy = VectorType::get(newShape.back(), elemTy); + Value newRes = rewriter.create( + loc, SplatElementsAttr::get(newOpTy, rewriter.getFloatAttr(elemTy, 0))); + auto strides = computeStrides(newShape); + // Remove the last stride to produce sub-vector indices. + strides.pop_back(); + for (int64_t idx = 0; idx < numElems; idx += newShape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(reshapedInputs.size()); + std::transform(reshapedInputs.begin(), reshapedInputs.end(), + subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, + indices); + }); + Value subRes = rewriter.create(loc, subResTy, subInputs); + newRes = rewriter.create(loc, subRes, newRes, indices); + } + + // Reshape the result back to the original type. + rewriter.replaceOpWithNewOp(op, vecTy, newRes); + return success(); + } +}; + +template +struct VecOpToLibmvecCall : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + VecOpToLibmvecCall(MLIRContext *context, StringRef fp32FnBaseName, + StringRef fp64FnBaseName) + : OpRewritePattern(context) { + this->fp32FnBaseName = fp32FnBaseName; + this->fp64FnBaseName = fp64FnBaseName; + } + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy || vecTy.getRank() > 1) + return failure(); + + Type elemTy = vecTy.getElementType(); + if (!elemTy.isF32() && !elemTy.isF64()) + return failure(); + + auto baseName = elemTy.isF32() ? fp32FnBaseName : fp64FnBaseName; + int64_t vecSize = vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); + std::string isaPrefix; + if (vecSize == 128) { + isaPrefix = "b"; + } else if (vecSize == 256) { + isaPrefix = "d"; + } else if (vecSize == 512) { + isaPrefix = "e"; + } else { + return failure(); + } + std::string fnName = + "_ZGV" + isaPrefix + "N" + std::to_string(vecTy.getNumElements()); + for (auto operand : op->getOperands()) + fnName += "v"; + fnName += "_" + baseName; + + auto module = SymbolTable::getNearestSymbolTable(op); + auto opFunc = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, fnName)); + // Generate function declaration if it doesn't exists yet. 
+ if (!opFunc) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module->getRegion(0).front()); + auto fnTy = FunctionType::get( + rewriter.getContext(), op->getOperandTypes(), op->getResultTypes()); + opFunc = + rewriter.create(rewriter.getUnknownLoc(), fnName, fnTy); + opFunc.setPrivate(); + opFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(), + UnitAttr::get(rewriter.getContext())); + } + + rewriter.replaceOpWithNewOp(op, fnName, op.getType(), + op->getOperands()); + return success(); + } + +private: + std::string fp32FnBaseName; + std::string fp64FnBaseName; +}; + +template +void populatePatternsForOp(RewritePatternSet &patterns, StringRef fp32FnName, + StringRef fp64FnName) { + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext(), fp32FnName, + fp64FnName); +} + +struct MathToLibmvecPass + : public mlir::triton::cpu::impl::MathToLibmvecBase { + using MathToLibmvecBase::MathToLibmvecBase; + + void runOnOperation() override { + Operation *op = getOperation(); + MLIRContext *context = op->getContext(); + + RewritePatternSet patterns(context); + populatePatternsForOp(patterns, "acosf", "acos"); + populatePatternsForOp(patterns, "acoshf", "acosh"); + populatePatternsForOp(patterns, "asinf", "asin"); + populatePatternsForOp(patterns, "asinhf", "asinh"); + populatePatternsForOp(patterns, "atanf", "atan"); + populatePatternsForOp(patterns, "atanhf", "atanh"); + populatePatternsForOp(patterns, "cbrtf", "cbrt"); + populatePatternsForOp(patterns, "cosf", "cos"); + populatePatternsForOp(patterns, "coshf", "cosh"); + populatePatternsForOp(patterns, "erff", "erf"); + populatePatternsForOp(patterns, "expf", "exp"); + populatePatternsForOp(patterns, "exp2f", "exp2"); + populatePatternsForOp(patterns, "logf", "log"); + populatePatternsForOp(patterns, "log2f", "log2"); + populatePatternsForOp(patterns, "log10f", "log10"); + populatePatternsForOp(patterns, "log1pf", "log1p"); + populatePatternsForOp(patterns, "sinf", "sin"); + populatePatternsForOp(patterns, "sinhf", "sinh"); + populatePatternsForOp(patterns, "tanf", "tan"); + populatePatternsForOp(patterns, "tanhf", "tanh"); + + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // anonymous namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createMathToLibmvecPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index d568a7bd5bb2..ec51257caf67 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -301,15 +301,40 @@ struct ConvertMixedPrecisionMatmul } }; +template struct PromoteOpToFp32 : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PromoteOpToFp32(MLIRContext *context) : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Type opTy = op.getType(); + + if (!isFp8(opTy) && !isFp16(opTy) && !isBf16(opTy)) + return failure(); + + Type fp32Ty = toFp32(opTy); + SmallVector fp32Ops; + for (auto operand : op->getOperands()) + fp32Ops.push_back(rewriter.create(loc, fp32Ty, operand)); + auto newOp = rewriter.create(loc, fp32Ty, fp32Ops); + 
rewriter.replaceOpWithNewOp(op, opTy, newOp); + return success(); + } +}; + struct ConvertUnsupportedOps : public triton::cpu::impl::ConvertUnsupportedOpsBase< ConvertUnsupportedOps> { ConvertUnsupportedOps() = default; ConvertUnsupportedOps(bool promoteBf16ToFp32, - bool convertMixedPrecisionMatmul) { + bool convertMixedPrecisionMatmul, + bool promoteLibMathToFp32) { this->promoteBf16ToFp32 = promoteBf16ToFp32; this->convertMixedPrecisionMatmul = convertMixedPrecisionMatmul; + this->promoteLibMathToFp32 = promoteLibMathToFp32; } void runOnOperation() override { @@ -333,6 +358,35 @@ struct ConvertUnsupportedOps if (convertMixedPrecisionMatmul) { patterns.add(context); } + if (promoteLibMathToFp32) { + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + patterns.add>(context); + } if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) return signalPassFailure(); @@ -351,9 +405,10 @@ std::unique_ptr> createConvertUnsupportedOps() { std::unique_ptr> createConvertUnsupportedOps(bool promoteBf16ToFp32, - bool convertMixedPrecisionMatmul) { - return std::make_unique(promoteBf16ToFp32, - convertMixedPrecisionMatmul); + bool convertMixedPrecisionMatmul, + bool promoteLibMathToFp32) { + return std::make_unique( + promoteBf16ToFp32, convertMixedPrecisionMatmul, promoteLibMathToFp32); } } // namespace cpu diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 3959bf28f4e1..82c3c92ce89c 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -32,9 +32,10 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { }); m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm, bool promote_bf16_to_fp32, - bool convert_mixed_precision_matmul) { + bool convert_mixed_precision_matmul, bool promote_lib_math_to_fp32) { pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps( - promote_bf16_to_fp32, convert_mixed_precision_matmul)); + promote_bf16_to_fp32, convert_mixed_precision_matmul, + promote_lib_math_to_fp32)); }); m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm, bool decomposeBf16Conversions, @@ -71,6 +72,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); }); + m.def("add_math_to_libmvec", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createMathToLibmvecPass()); + }); m.def("add_math_to_libm", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertMathToLibmPass()); }); From e35936deb6d8265d0d50ecf1690d4e0fa40ac972 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 6 Aug 2024 11:35:51 -0400 Subject: [PATCH 068/165] Make tl.debug_barrier() a no-op on CPU (#89) --- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp 
b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index e6b6a531059c..2bad397c9b77 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -3,6 +3,8 @@ #include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "mlir/Dialect/GPU/IR/GPUOps.h.inc" + #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" @@ -164,6 +166,23 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { } }; +using BarrierOp = mlir::gpu::BarrierOp; + +// This is part of the DebugOps pass because gpu::barrier is generated by +// tl.debug_barrier. +struct BarrierOpConversion : public ConvertOpToLLVMPattern { + explicit BarrierOpConversion(LLVMTypeConverter &typeConverter) + : mlir::ConvertOpToLLVMPattern(typeConverter) {} + + LogicalResult + matchAndRewrite(BarrierOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Just make it a no-op for now + rewriter.eraseOp(op); + return success(); + } +}; + struct DebugOpsToLLVM : public triton::impl::DebugOpsToLLVMBase { using DebugOpsToLLVMBase::DebugOpsToLLVMBase; @@ -180,6 +199,7 @@ struct DebugOpsToLLVM RewritePatternSet patterns(context); patterns.add(typeConverter); + patterns.add(typeConverter); // patterns.add(typeConverter); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { From ee88d7e1925f9541d70e7b7fee75b01938f528b8 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 6 Aug 2024 13:20:13 -0400 Subject: [PATCH 069/165] ConvertMemoryOps should not use cf dialect (#91) Doing so is likely to break invariants that the scf dialect expects. In particular, when handling a masked load or store within a scf.for loop, ConvertMemoryOps would create multiple blocks within that loop, resulting in errors of the form 'scf.for' op expects region #0 to have 0 or 1 blocks I have added a test to exercise the `lowerToScalar{Loads,Stores}` codepaths. I have not included the for loop in those tests because I have already removed "cf" as a legal dialect for the ConvertMemoryOps pass, which should prevent future errors. --- test/TritonCPU/convert-masked.mlir | 71 ++++++++++++++ .../TritonToTritonCPU/ConvertMemoryOps.cpp | 94 ++++++++----------- 2 files changed, 111 insertions(+), 54 deletions(-) create mode 100644 test/TritonCPU/convert-masked.mlir diff --git a/test/TritonCPU/convert-masked.mlir b/test/TritonCPU/convert-masked.mlir new file mode 100644 index 000000000000..c2ec0fa4f742 --- /dev/null +++ b/test/TritonCPU/convert-masked.mlir @@ -0,0 +1,71 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops | FileCheck %s + +// Convert strided masked loads to scalar loads. 
+ +// CHECK-LABEL: @strided_masked_loads +// CHECK: %[[COND:.+]] = vector.extract %[[MASK:.+]][[[#IDX:]]] : i1 +// CHECK-NEXT: scf.if %[[COND]] -> (vector<32xi32>) { +// CHECK-NEXT: %[[PTR:.+]] = vector.extract %[[IN:.+]][[[#IDX]]] : i64 from vector<32xi64> +// CHECK-NEXT: %[[PTR_:.+]] = tt.int_to_ptr %[[PTR]] : i64 -> !tt.ptr +// CHECK-NEXT: %[[VAL:.+]] = tt.load %[[PTR_]] : !tt.ptr +// CHECK-NEXT: %[[NEW_OUT:.+]] = vector.insert %[[VAL]], %[[OUT:.+]] [[[#IDX]]] : i32 into vector<32xi32> +// CHECK-NEXT: scf.yield %[[NEW_OUT]] : vector<32xi32> +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[OUT]] : vector<32xi32> +// CHECK-NEXT: } + +module { + tt.func public @strided_masked_loads(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %c1_i32 = arith.constant 1 : i32 + %c10_i32 = arith.constant 10 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<2> : tensor<32xi32> + %cst_0 = arith.constant dense<16> : tensor<32xi32> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = arith.cmpi slt, %0, %cst_0 : tensor<32xi32> + %2 = arith.muli %0, %cst : tensor<32xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %4 = tt.addptr %3, %2 : tensor<32x!tt.ptr>, tensor<32xi32> + scf.for %arg1 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %5 = tt.load %4, %1 : tensor<32x!tt.ptr> + tt.store %4, %5 : tensor<32x!tt.ptr> + } + tt.return + } +} + +// ----- + +// Convert strided masked stores to scalar stores. + +// CHECK-LABEL: @strided_masked_stores +// CHECK: %[[COND:.+]] = vector.extract %[[MASK:.+]][[[#IDX:]]] : i1 from vector<32xi1> +// CHECK-NEXT: scf.if %[[COND]] { +// CHECK-NEXT: %[[PTR:.+]] = vector.extract %[[OUT:.+]][[[#IDX]]] : i64 from vector<32xi64> +// CHECK-NEXT: %[[PTR_:.+]] = tt.int_to_ptr %[[PTR]] : i64 -> !tt.ptr +// CHECK-NEXT: %[[VAL:.+]] = vector.extract %[[IN:.+]][[[#IDX]]] : i32 from vector<32xi32> +// CHECK-NEXT: tt.store %[[PTR_]], %[[VAL]] : !tt.ptr +// CHECK-NEXT: } + +module { + tt.func public @strided_masked_stores(%arg0: !tt.ptr {tt.divisibility = 16 : i32} ) { + %c1_i32 = arith.constant 1 : i32 + %c10_i32 = arith.constant 10 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<64> : tensor<32xi32> + %cst_0 = arith.constant dense<2> : tensor<32xi32> + %cst_1 = arith.constant dense<16> : tensor<32xi32> + %0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32> + %1 = arith.cmpi slt, %0, %cst_1 : tensor<32xi32> + %2 = arith.muli %0, %cst_0 : tensor<32xi32> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<32x!tt.ptr> + %4 = tt.addptr %3, %2 : tensor<32x!tt.ptr>, tensor<32xi32> + %5 = arith.subi %cst, %2 : tensor<32xi32> + %6 = tt.addptr %3, %5 : tensor<32x!tt.ptr>, tensor<32xi32> + scf.for %arg1 = %c0_i32 to %c10_i32 step %c1_i32 : i32 { + %7 = tt.load %4 : tensor<32x!tt.ptr> + tt.store %6, %7, %1 : tensor<32x!tt.ptr> + } + tt.return + } +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 8896ab5ebd8d..cada363a9381 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -233,41 +233,36 @@ struct LoadOpConversion : public MemoryOpConversion { auto cache = loadOp.getCache(); auto evict = loadOp.getEvict(); auto isVolatile = loadOp.getIsVolatile(); - Value dst = convertOtherVal(loadOp, rewriter); - int64_t numElems = vecTy.getNumElements(); - auto strides = computeStrides(vecTy.getShape()); - for (auto idx = 0; idx < numElems; ++idx) { - 
auto indices = delinearize(idx, strides); - Block *headerBlock = rewriter.getBlock(); - Block *condBlock = nullptr; - Value origDst = dst; - // Create a conditional block for load if there is a mask. - if (mask) { - condBlock = - rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(condBlock); - } + auto loadOne = [=, &rewriter](ArrayRef indices, Value dst) { Value ptr = rewriter.create(loc, ptrs, indices); ptr = rewriter.create(loc, ptrTy, ptr); Value val = rewriter.create(loc, ptr, cache, evict, isVolatile); - dst = rewriter.create(loc, val, dst, indices); - - // Add predicate and branches. - if (mask) { - Block *footerBlock = - rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); - Value resDst = dst; - dst = footerBlock->addArgument(dst.getType(), dst.getLoc()); - rewriter.setInsertionPointToEnd(headerBlock); - auto predicate = rewriter.create(loc, mask, indices); - rewriter.create(loc, predicate, condBlock, - footerBlock, origDst); - rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, footerBlock, resDst); - rewriter.setInsertionPointToStart(footerBlock); + return rewriter.create(loc, val, dst, indices); + }; + + Value dst = convertOtherVal(loadOp, rewriter); + int64_t numElems = vecTy.getNumElements(); + auto strides = computeStrides(vecTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + if (!mask) { + dst = loadOne(indices, dst); + continue; } + // Create a conditional block for load if there is a mask. + auto predicate = rewriter.create(loc, mask, indices); + auto ifOp = rewriter.create( + loc, predicate, + [&](OpBuilder &builder, Location loc) { + auto result = loadOne(indices, dst).getResult(); + rewriter.create(loc, result); + }, + [&](OpBuilder &builder, Location loc) { + rewriter.create(loc, dst); + }); + dst = ifOp.getResult(0); } rewriter.replaceOp(loadOp, dst); @@ -381,36 +376,28 @@ struct StoreOpConversion : public MemoryOpConversion { auto cache = storeOp.getCache(); auto evict = storeOp.getEvict(); - int64_t numElems = tensorTy.getNumElements(); - auto strides = computeStrides(tensorTy.getShape()); - for (auto idx = 0; idx < numElems; ++idx) { - auto indices = delinearize(idx, strides); - Block *headerBlock = rewriter.getBlock(); - Block *condBlock = nullptr; - // Create a conditional block for store if there is a mask. - if (mask) { - condBlock = - rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(condBlock); - } - + auto storeOne = [=, &rewriter](ArrayRef indices) { Value ptr = rewriter.create(loc, ptrs, indices); ptr = rewriter.create(loc, ptrTy, ptr); Value val = rewriter.create(loc, vals, indices); rewriter.create(loc, ptr, val, cache, evict); + }; - // Add predicate and branches. - if (mask) { - Block *footerBlock = - rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToEnd(headerBlock); - auto predicate = rewriter.create(loc, mask, indices); - rewriter.create(loc, predicate, condBlock, - footerBlock); - rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, footerBlock); - rewriter.setInsertionPointToStart(footerBlock); + int64_t numElems = tensorTy.getNumElements(); + auto strides = computeStrides(tensorTy.getShape()); + for (auto idx = 0; idx < numElems; ++idx) { + auto indices = delinearize(idx, strides); + if (!mask) { + storeOne(indices); + continue; } + // Create a conditional block for store if there is a mask. 
+ auto predicate = rewriter.create(loc, mask, indices); + rewriter.create(loc, predicate, + [&](OpBuilder &builder, Location loc) { + storeOne(indices); + rewriter.create(loc); + }); } rewriter.eraseOp(storeOp); @@ -425,7 +412,6 @@ class MemoryOpConversionTarget : public ConversionTarget { addLegalDialect(); addLegalDialect(); addLegalDialect(); - addLegalDialect(); addLegalDialect(); addLegalDialect(); addLegalOp(); From a95b8ebead45a2f8b29bbc52e37eda872d17c974 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 6 Aug 2024 19:09:11 -0500 Subject: [PATCH 070/165] Remove registered pipelines in favor of explicit lists in python. (#93) Signed-off-by: Ilya Enkovich --- bin/RegisterTritonDialects.h | 2 - third_party/cpu/backend/compiler.py | 17 ++++++- .../cpu/include/TritonCPUToLLVM/Passes.h | 3 -- .../cpu/include/TritonToTritonCPU/Passes.h | 3 -- .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 1 - .../cpu/lib/TritonCPUToLLVM/Pipeline.cpp | 26 ---------- .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 1 - .../cpu/lib/TritonToTritonCPU/Pipeline.cpp | 32 ------------- third_party/cpu/triton_cpu.cc | 47 +++++++++++++++++-- 9 files changed, 58 insertions(+), 74 deletions(-) delete mode 100644 third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp delete mode 100644 third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index c67987e8159e..41dc478fd7ce 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -74,10 +74,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // CPU passes mlir::triton::cpu::registerTritonToTritonCPUPasses(); - mlir::triton::cpu::registerTritonToTritonCPUPipeline(); mlir::triton::cpu::registerTritonCPUTransformsPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPasses(); - mlir::triton::cpu::registerTritonCPUToLLVMPipeline(); // TODO: register Triton & TritonGPU passes registry.insert TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - cpu.passes.ttcpuir.add_triton_to_triton_cpu_pipeline(pm) + cpu.passes.ttcpuir.add_convert_memory_ops(pm) + cpu.passes.ttcpuir.add_convert_ptr_ops(pm) + cpu.passes.ttcpuir.add_convert_elementwise_ops(pm) + cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm) + cpu.passes.ttcpuir.add_convert_dot_op(pm) + cpu.passes.ttcpuir.add_convert_histogram_op(pm) + cpu.passes.ttcpuir.add_convert_reduction_op(pm) + cpu.passes.ttcpuir.add_convert_scan_op(pm) + cpu.passes.ttcpuir.add_convert_cf_ops(pm) + cpu.passes.ttcpuir.add_convert_atomic_ops(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) passes.common.add_canonicalizer(pm) @@ -134,7 +143,11 @@ def make_llir(self, src, metadata, options): cpu.passes.ttcpuir.add_lower_affine(pm) passes.convert.add_scf_to_cf(pm) passes.convert.add_index_to_llvmir(pm) - cpu.passes.ttcpuir.add_triton_cpu_to_llvmir_pipeline(pm) + cpu.passes.ttcpuir.add_func_op_to_llvmir(pm) + cpu.passes.ttcpuir.add_program_id_to_llvmir(pm) + cpu.passes.ttcpuir.add_memory_op_to_llvmir(pm) + cpu.passes.ttcpuir.add_atomic_ops_to_llvmir(pm) + cpu.passes.ttcpuir.add_debug_ops_to_llvmir(pm) if self.cpu_arch == "x86_64" and "avx512f" in self.cpu_features: cpu.passes.ttcpuir.add_math_to_libmvec(pm) passes.convert.add_math_to_llvmir(pm) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index f00fc0560c3f..556ac14bb669 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -28,9 +28,6 @@ 
std::unique_ptr> createAtomicOpsToLLVMPass(); std::unique_ptr> createDebugOpsToLLVMPass(); std::unique_ptr> createMathToLibmvecPass(); -void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm); -void registerTritonCPUToLLVMPipeline(); - #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUToLLVM/Passes.h.inc" diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 14df893f0bac..303b99ce3c43 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -29,9 +29,6 @@ std::unique_ptr> createConvertReductionOp(); std::unique_ptr> createConvertScanOp(); std::unique_ptr> createConvertAtomicOps(); -void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); -void registerTritonToTritonCPUPipeline(); - #define GEN_PASS_REGISTRATION #include "cpu/include/TritonToTritonCPU/Passes.h.inc" diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index f355b6d46d4e..b4c6372132be 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -6,7 +6,6 @@ add_triton_library(TritonCPUToLLVM LowerMultiReduction.cpp MathToLibmvec.cpp MemoryOpToLLVM.cpp - Pipeline.cpp TypeConverter.cpp Utility.cpp diff --git a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp b/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp deleted file mode 100644 index 8c02cc944b75..000000000000 --- a/third_party/cpu/lib/TritonCPUToLLVM/Pipeline.cpp +++ /dev/null @@ -1,26 +0,0 @@ -#include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Conversion/Passes.h" -#include "mlir/Pass/PassManager.h" - -namespace mlir { -namespace triton { -namespace cpu { - -void tritonCPUToLLVMPipelineBuilder(OpPassManager &pm) { - pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); - pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); - pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); - pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass()); - pm.addPass(mlir::triton::cpu::createDebugOpsToLLVMPass()); -} - -void registerTritonCPUToLLVMPipeline() { - PassPipelineRegistration<>("triton-cpu-to-llvmir", - "TritonCPU to LLVM conversion pipeline.", - tritonCPUToLLVMPipelineBuilder); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index dc34c5bd0199..b200a47da92d 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -9,7 +9,6 @@ add_triton_library(TritonToTritonCPU ConvertPtrOps.cpp ConvertReductionOp.cpp ConvertScanOp.cpp - Pipeline.cpp TypeConverter.cpp DEPENDS diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp deleted file mode 100644 index c7e7de72eecf..000000000000 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "cpu/include/TritonToTritonCPU/Passes.h" - -#include "mlir/Conversion/Passes.h" -#include "mlir/Pass/PassManager.h" - -namespace mlir { -namespace triton { -namespace cpu { - -void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); - pm.addPass(mlir::triton::cpu::createConvertPtrOps()); - pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); - 
pm.addPass(mlir::triton::cpu::createConvertElemManipOps()); - pm.addPass(mlir::triton::cpu::createConvertDotOp()); - pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); - pm.addPass(mlir::triton::cpu::createConvertReductionOp()); - pm.addPass(mlir::triton::cpu::createConvertScanOp()); - pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); - pm.addPass(mlir::triton::cpu::createConvertAtomicOps()); - // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); -} - -void registerTritonToTritonCPUPipeline() { - PassPipelineRegistration<>("triton-to-triton-cpu", - "Triton to TritonCPU conversion pipeline.", - tritonToTritonCPUPipelineBuilder); -} - -} // namespace cpu -} // namespace triton -} // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 82c3c92ce89c..8a2b7f1642ff 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -24,11 +24,35 @@ namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; - m.def("add_triton_to_triton_cpu_pipeline", [](mlir::PassManager &pm) { - mlir::triton::cpu::tritonToTritonCPUPipelineBuilder(pm); + m.def("add_convert_memory_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); }); - m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) { - mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm); + m.def("add_convert_ptr_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertPtrOps()); + }); + m.def("add_convert_elementwise_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); + }); + m.def("add_convert_elem_manip_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertElemManipOps()); + }); + m.def("add_convert_dot_op", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDotOp()); + }); + m.def("add_convert_histogram_op", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); + }); + m.def("add_convert_reduction_op", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertReductionOp()); + }); + m.def("add_convert_scan_op", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertScanOp()); + }); + m.def("add_convert_cf_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); + }); + m.def("add_convert_atomic_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertAtomicOps()); }); m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm, bool promote_bf16_to_fp32, @@ -55,6 +79,21 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { pm.addNestedPass( mlir::triton::cpu::createLowerMultiReductionPass()); }); + m.def("add_func_op_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createFuncOpToLLVMPass()); + }); + m.def("add_program_id_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createGetProgramIdOpToLLVMPass()); + }); + m.def("add_memory_op_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createMemoryOpToLLVMPass()); + }); + m.def("add_atomic_ops_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createAtomicOpsToLLVMPass()); + }); + m.def("add_debug_ops_to_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createDebugOpsToLLVMPass()); + }); m.def("add_vector_to_llvmir", [](mlir::PassManager &pm, bool 
reassoc_fp_reduction) { mlir::ConvertVectorToLLVMPassOptions opts; From dcc69d2b6ae1108f4eddc984f522bb59270aaa88 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Tue, 6 Aug 2024 22:53:17 -0400 Subject: [PATCH 071/165] Don't use cf dialect in ConvertAtomicOps (#94) Similar to https://github.com/triton-lang/triton-cpu/pull/91 I also renamed convert-masked.mlir to convert-memory-ops.mlir because each test file should test one single pass, rather than one "feature" like masking. --- test/TritonCPU/convert-atomic.mlir | 36 +++++++++++++++++++ ...rt-masked.mlir => convert-memory-ops.mlir} | 0 .../TritonToTritonCPU/ConvertAtomicOps.cpp | 32 +++++++---------- 3 files changed, 49 insertions(+), 19 deletions(-) create mode 100644 test/TritonCPU/convert-atomic.mlir rename test/TritonCPU/{convert-masked.mlir => convert-memory-ops.mlir} (100%) diff --git a/test/TritonCPU/convert-atomic.mlir b/test/TritonCPU/convert-atomic.mlir new file mode 100644 index 000000000000..b0cad10f0d8f --- /dev/null +++ b/test/TritonCPU/convert-atomic.mlir @@ -0,0 +1,36 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-atomic-ops | FileCheck %s + +// Convert atomic ops with non-constant masks into scf.if + maskless atomic op. +// Check that the final tt.atomic_rmw only has 5 parameters (the 6th would be the mask). + +// CHECK-LABEL: @atomic_mask +// CHECK: %[[COND:.+]] = vector.extract %{{.+}}[[[#IDX:]]] : i1 from vector<16xi1> +// CHECK-NEXT: scf.if %[[COND]] -> (f32) { +// CHECK-NEXT: %[[OLD:.+]] = tt.atomic_rmw fadd, acq_rel, gpu, %{{[^%]+}} %{{[^%]+}} : (!tt.ptr, f32) -> f32 +// CHECK-NEXT: scf.yield %[[OLD]] : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: scf.yield %[[CST]] : f32 +// CHECK-NEXT: } + +module { + tt.func public @atomic_mask(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]> : vector<16xi64> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<5.000000e-01> : vector<16xf32> + %cst_1 = arith.constant dense<3.000000e+00> : vector<16xf32> + %0 = builtin.unrealized_conversion_cast %cst_1 : vector<16xf32> to tensor<16xf32> + %1 = tt.ptr_to_int %arg0 : !tt.ptr -> i64 + %2 = vector.splat %1 : vector<16xi64> + %3 = arith.addi %2, %cst : vector<16xi64> + %4 = builtin.unrealized_conversion_cast %3 : vector<16xi64> to tensor<16x!tt.ptr> + %5 = vector.extract %3[0] : i64 from vector<16xi64> + %6 = tt.int_to_ptr %5 : i64 -> !tt.ptr + %7 = triton_cpu.ptr_to_memref %6 : -> memref<16xf32> + %8 = vector.load %7[%c0] : memref<16xf32>, vector<16xf32> + %9 = arith.cmpf olt, %8, %cst_0 : vector<16xf32> + %10 = builtin.unrealized_conversion_cast %9 : vector<16xi1> to tensor<16xi1> + %11 = tt.atomic_rmw fadd, acq_rel, gpu, %4, %0, %10 : (tensor<16x!tt.ptr>, tensor<16xf32>, tensor<16xi1>) -> tensor<16xf32> + tt.return + } +} diff --git a/test/TritonCPU/convert-masked.mlir b/test/TritonCPU/convert-memory-ops.mlir similarity index 100% rename from test/TritonCPU/convert-masked.mlir rename to test/TritonCPU/convert-memory-ops.mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp index 61d3ac65e2fc..473e97ec8d77 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp @@ -30,7 +30,6 @@ class AtomicConversionTarget : public ConversionTarget { : ConversionTarget(ctx) { addLegalDialect(); 
addLegalDialect(); - addLegalDialect(); addLegalDialect(); addLegalDialect(); addLegalDialect(); @@ -121,24 +120,19 @@ struct AtomicRMWOpConversion : public OpConversionPattern { } } - Block *headerBlock = rewriter.getBlock(); - Value zero = rewriter.create( - loc, rewriter.getZeroAttr(val.getType())); - Block *condBlock = - rewriter.splitBlock(headerBlock, rewriter.getInsertionPoint()); - rewriter.setInsertionPointToStart(condBlock); - Value resVal = rewriter.create( - loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); - Block *footerBlock = - rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); - Value res = footerBlock->addArgument(resVal.getType(), resVal.getLoc()); - rewriter.setInsertionPointToEnd(headerBlock); - rewriter.create(loc, mask, condBlock, footerBlock, zero); - rewriter.setInsertionPointToEnd(condBlock); - rewriter.create(loc, footerBlock, resVal); - rewriter.setInsertionPointToStart(footerBlock); - - return res; + auto ifOp = rewriter.create( + loc, mask, + [&](OpBuilder &builder, Location loc) { + Value resVal = rewriter.create( + loc, val.getType(), rmwOp, ptr, val, nullptr, sem, scope); + rewriter.create(loc, resVal); + }, + [&](OpBuilder &builder, Location loc) { + Value zero = rewriter.create( + loc, rewriter.getZeroAttr(val.getType())); + rewriter.create(loc, zero); + }); + return ifOp.getResult(0); } arith::ConstantOp getConstMaskDef(Value mask) const { From 9aa875725b29e2d7d9ff2b89146780321addef08 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Wed, 7 Aug 2024 10:21:44 -0400 Subject: [PATCH 072/165] atomic_rmw ops should return original value (#95) We were previously discarding the original value & returning all zeros. --- third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp index 473e97ec8d77..bab0cd94c57e 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp @@ -72,7 +72,7 @@ struct AtomicRMWOpConversion : public OpConversionPattern { auto ptrTy = cast(op.getPtr().getType()).getElementType(); auto vecTy = cast(vals.getType()); auto strides = computeStrides(vecTy.getShape()); - auto res = + Value res = rewriter.create(loc, rewriter.getZeroAttr(vecTy)); int64_t numElems = vecTy.getNumElements(); for (int64_t idx = 0; idx < numElems; ++idx) { @@ -97,7 +97,7 @@ struct AtomicRMWOpConversion : public OpConversionPattern { // Elements with const false mask are skipped. if (resElem) { - rewriter.create(loc, resElem, res, indices); + res = rewriter.create(loc, resElem, res, indices); } } From 9d4200e1aea9054b1ed34a3774021ee43692e69a Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 7 Aug 2024 13:24:42 -0500 Subject: [PATCH 073/165] Compute a scalar pointer for vector load instead of extracting it from a tensor (#92) * Compute a scalar pointer for vector load instead of extracting it from a tensor. Signed-off-by: Ilya Enkovich * Add lit test for scalar ptr usage. 
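A rough Python sketch of the idea (illustrative only: the real pass works on MLIR ops, the
tuple encoding and function names below are invented, and the element-size scaling done by
tt.addptr is ignored). When every producer of the pointer tensor is a splat, make_range,
addptr or addi, the lane-0 address can be rebuilt as a scalar expression instead of being
extracted from a materialized vector of pointers:

    # Hypothetical model of the scalar-pointer rebuild (not the pass itself).
    # An expression is a nested tuple: ("splat", scalar), ("make_range", start),
    # ("addptr", ptr, offs) or ("add", lhs, rhs).
    def can_compute_scalar(e):
        if e[0] in ("addptr", "add"):
            return can_compute_scalar(e[1]) and can_compute_scalar(e[2])
        return e[0] in ("splat", "make_range")

    def compute_scalar(e, lane):
        if e[0] in ("addptr", "add"):
            return compute_scalar(e[1], lane) + compute_scalar(e[2], lane)
        if e[0] == "splat":
            return e[1]           # same scalar for every lane
        return e[1] + lane        # make_range: start + lane index

    # src + tl.arange(0, 128): lane 0 is just `src`; no vector of pointers is needed.
    expr = ("addptr", ("splat", 0x1000), ("make_range", 0))
    assert can_compute_scalar(expr) and compute_scalar(expr, 0) == 0x1000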
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- .github/workflows/build-test.yml | 3 +- python/test/unit/cpu/test_opt.py | 37 ++++++++++++ test/TritonCPU/convert-memory-ops.mlir | 23 ++++++++ .../TritonToTritonCPU/ConvertMemoryOps.cpp | 58 ++++++++++++++++++- 4 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 python/test/unit/cpu/test_opt.py diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 853196ebbba7..04417f6c5485 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -84,7 +84,8 @@ jobs: python/test/unit/language/test_block_pointer.py \ python/test/unit/language/test_conversions.py \ python/test/unit/cpu/test_libdevice.py \ - python/test/unit/cpu/test_libmvec.py + python/test/unit/cpu/test_libmvec.py \ + python/test/unit/cpu/test_opt.py - name: Run lit tests run: | diff --git a/python/test/unit/cpu/test_opt.py b/python/test/unit/cpu/test_opt.py new file mode 100644 index 000000000000..4051bfdd5bbe --- /dev/null +++ b/python/test/unit/cpu/test_opt.py @@ -0,0 +1,37 @@ +import os +import torch + +import triton +import triton.language as tl + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + +def is_x86(): + return is_cpu() and \ + triton.runtime.driver.active.get_current_target().arch == "x86_64" + + +def test_scalar_pointer_arith(device): + + @triton.jit + def kernel(src, dst, BLOCK_SIZE: tl.constexpr): + offs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + offs) + tl.store(dst + offs, x) + + src = torch.rand((128, ), dtype=torch.float32, device=device) + res = torch.empty_like(src) + meta = kernel[(1, )](src, res, BLOCK_SIZE=128) + assert (src == res).all() + + # Check TTCIR doesn't have pointer extraction from a tensor. 
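+    # (The "ttcir" dump is the TritonCPU IR for the kernel; with the scalar
+    # pointer rebuild the load and store derive their base address from the
+    # scalar `src`/`dst` arguments, so no vector.extract of a pointer lane
+    # should remain.)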
+ ttcir = meta.asm["ttcir"] + assert ttcir.count("extract") == 0 diff --git a/test/TritonCPU/convert-memory-ops.mlir b/test/TritonCPU/convert-memory-ops.mlir index c2ec0fa4f742..c98747269fdc 100644 --- a/test/TritonCPU/convert-memory-ops.mlir +++ b/test/TritonCPU/convert-memory-ops.mlir @@ -69,3 +69,26 @@ module { tt.return } } + +// ----- + +// Check that pointer for vector load/store is not extracted from a vector + +// CHECK-LABEL: @scalar_ptrs +// CHECK-NOT: vector.extract {{.+}} : i64 from vector<128xi64> +// CHECK: {{.+}} = vector.load {{.+}} : memref<128xf32>, vector<128xf32> +// CHECK-NOT: vector.extract {{.+}} : i64 from vector<128xi64> +// CHECK: vector.store {{.+}}, {{.+}} : memref<128xf32>, vector<128xf32> + +module { + tt.func public @scalar_ptrs(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.load %2 : tensor<128x!tt.ptr> + %4 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %5, %3 : tensor<128x!tt.ptr> + tt.return + } +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index cada363a9381..e872aa63fdff 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -48,7 +48,14 @@ struct MemoryOpConversion : public OpConversionPattern { Value extractScalarPointer(Location loc, Value ptrs, ArrayRef indices, ConversionPatternRewriter &rewriter) const { - // TODO: Analyze data flow and build scalar pointer computation code. + // If we build a vector of pointers and the extract a pointer from it, then + // compiler doesn't always optimize it to a simple scalar pointer + // computation. Here we try to follow a data flow of the tensor to rebuild a + // scalar pointer for more efficient resulting code. + if (canComputeScalarValue(ptrs)) + return computeScalarValue(ptrs, indices, rewriter); + + // Fall back to a scalar pointer extraction from the vector. 
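+    // (This is the case, for example, when the pointer tensor is a block
+    // argument such as a loop-carried value, or is produced by an op the
+    // analysis above does not understand.)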
Value ptr = rewriter.create( loc, rewriter.getRemappedValue(ptrs), indices); auto ptrTy = dyn_cast(ptrs.getType()).getElementType(); @@ -56,6 +63,55 @@ struct MemoryOpConversion : public OpConversionPattern { return ptr; } + bool canComputeScalarValue(Value vals) const { + if (auto def = vals.getDefiningOp()) { + return canComputeScalarValue(def.getPtr()) && + canComputeScalarValue(def.getOffset()); + } + + if (auto def = vals.getDefiningOp()) { + return canComputeScalarValue(def.getLhs()) && + canComputeScalarValue(def.getRhs()); + } + + if (vals.getDefiningOp() || vals.getDefiningOp()) { + return true; + } + + return false; + } + + Value computeScalarValue(Value vals, ArrayRef indices, + ConversionPatternRewriter &rewriter) const { + if (auto def = vals.getDefiningOp()) { + Value ptr = computeScalarValue(def.getPtr(), indices, rewriter); + Value offs = computeScalarValue(def.getOffset(), indices, rewriter); + return rewriter.create(def.getLoc(), ptr.getType(), ptr, offs); + } + + if (auto def = vals.getDefiningOp()) { + Value lhs = computeScalarValue(def.getLhs(), indices, rewriter); + Value rhs = computeScalarValue(def.getRhs(), indices, rewriter); + return rewriter.create(def.getLoc(), lhs.getType(), lhs, + rhs); + } + + if (auto def = vals.getDefiningOp()) { + return def.getSrc(); + } + + if (auto def = vals.getDefiningOp()) { + int32_t start = static_cast(def.getStart()); + assert(indices.size() == 1); + Type elemTy = cast(def.getType()).getElementType(); + return rewriter.create( + def.getLoc(), elemTy, + rewriter.getIntegerAttr(elemTy, start + indices[0])); + } + + return Value(); + } + Value extractMemRef(Location loc, Value ptr, ConversionPatternRewriter &rewriter) const { auto tensorTy = dyn_cast( From 77290744e2da033c2664e1052dfe04012103cb75 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 7 Aug 2024 18:32:34 -0500 Subject: [PATCH 074/165] Add pass to optimize masked loads and stores. (#96) Signed-off-by: Ilya Enkovich --- python/test/unit/cpu/test_opt.py | 30 ++ test/TritonCPU/optimize-masks.mlir | 35 ++ third_party/cpu/backend/compiler.py | 1 + .../cpu/include/TritonCPUTransforms/Passes.h | 1 + .../cpu/include/TritonCPUTransforms/Passes.td | 18 + .../lib/TritonCPUTransforms/CMakeLists.txt | 1 + .../lib/TritonCPUTransforms/OptimizeMasks.cpp | 393 ++++++++++++++++++ third_party/cpu/triton_cpu.cc | 3 + 8 files changed, 482 insertions(+) create mode 100644 test/TritonCPU/optimize-masks.mlir create mode 100644 third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp diff --git a/python/test/unit/cpu/test_opt.py b/python/test/unit/cpu/test_opt.py index 4051bfdd5bbe..538f0749868e 100644 --- a/python/test/unit/cpu/test_opt.py +++ b/python/test/unit/cpu/test_opt.py @@ -1,4 +1,5 @@ import os +import pytest import torch import triton @@ -35,3 +36,32 @@ def kernel(src, dst, BLOCK_SIZE: tl.constexpr): # Check TTCIR doesn't have pointer extraction from a tensor. 
ttcir = meta.asm["ttcir"] assert ttcir.count("extract") == 0 + + +@pytest.mark.parametrize("size", [32, 128, 135]) +@pytest.mark.parametrize("tile_size", [16]) +def test_optimize_tile_mask(size, tile_size, device): + + @triton.jit + def kernel(src, dst, size, TILE_SIZE: tl.constexpr): + for i in range(0, tl.cdiv(size, TILE_SIZE)): + offs = tl.arange(0, TILE_SIZE) + i * TILE_SIZE + mask = offs < size + x = tl.load(src + offs, mask=mask, other=0) + tl.store(dst + offs, x, mask=mask) + + src = torch.rand((size, ), dtype=torch.float32, device='cpu') + res = torch.empty_like(src) + meta = kernel[(1, )](src, res, size, TILE_SIZE=tile_size) + assert (src == res).all() + + # Check number of masked loads and stores. + tttcir = meta.asm["tttcir"] + masked_loads = tttcir.count("maskedload") + masked_stores = tttcir.count("maskedstore") + if size % tile_size == 0: + assert masked_loads == 0 + assert masked_stores == 0 + else: + assert masked_loads == 1 + assert masked_stores == 1 diff --git a/test/TritonCPU/optimize-masks.mlir b/test/TritonCPU/optimize-masks.mlir new file mode 100644 index 000000000000..2a013699cbac --- /dev/null +++ b/test/TritonCPU/optimize-masks.mlir @@ -0,0 +1,35 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-optimize-masks | FileCheck %s + +// Convert strided masked loads to scalar loads. + +// CHECK-LABEL: @remove_masks_in_for_loop +// CHECK: %[[VAL:.+]] = vector.load {{.+}} : memref<16xf32>, vector<16xf32> +// CHECK: vector.store %[[VAL]], {{.+}} : memref<16xf32>, vector<16xf32> + +module { + tt.func public @remove_masks_in_for_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> + %c15_i32 = arith.constant 15 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<16xf32> + %0 = arith.addi %arg2, %c15_i32 : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + %2 = vector.splat %arg2 : vector<16xi32> + scf.for %arg3 = %c0_i32 to %1 step %c1_i32 : i32 { + %3 = arith.muli %arg3, %c16_i32 : i32 + %4 = vector.splat %3 : vector<16xi32> + %5 = arith.addi %4, %cst : vector<16xi32> + %6 = arith.cmpi slt, %5, %2 : vector<16xi32> + %7 = tt.addptr %arg0, %3 : !tt.ptr, i32 + %8 = triton_cpu.ptr_to_memref %7 : -> memref<16xf32> + %9 = vector.maskedload %8[%c0], %6, %cst_0 : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + %10 = tt.addptr %arg1, %3 : !tt.ptr, i32 + %11 = triton_cpu.ptr_to_memref %10 : -> memref<16xf32> + vector.maskedstore %11[%c0], %6, %9 : memref<16xf32>, vector<16xi1>, vector<16xf32> + } + tt.return + } +} diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index e002b6125651..7d1a7e924f3b 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -112,6 +112,7 @@ def make_tttcir(self, mod, metadata, opt): # TTCIR -> Target TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() + cpu.passes.ttcpuir.add_optimize_masks(pm) promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features # We don't have any lowering for mixed precision matmuls, so always use casts for now convert_mixed_precision_matmul = True diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h 
index 0fab5b2e17ca..fffc485bbf3d 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -27,6 +27,7 @@ std::unique_ptr> createDecomposeFpConversions(); std::unique_ptr> createDecomposeFpConversions(bool decomposeBf16Conversions, bool decomposeFp8Conversions); +std::unique_ptr> createOptimizeMasks(); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index a40c47ab0287..7c9c59b40091 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -57,4 +57,22 @@ def DecomposeFpConversions : Pass<"triton-cpu-decompose-fp-conversions", "mlir:: "mlir::triton::cpu::TritonCPUDialect"]; } +def OptimizeMasks : Pass<"triton-cpu-optimize-masks", "mlir::ModuleOp"> { + let summary = "Optimize masked memory accesses."; + let description = [{ + This pass tries to detect masked memory accesses with mask values that + can be proven to be all-ones or all-zeros. + }]; + + let options = [ + ]; + + let constructor = "mlir::triton::cpu::createOptimizeMasks()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt index 5a52aa7e86b6..afbe29d8aea5 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonCPUTransforms ConvertUnsupportedOps.cpp DecomposeFpConversions.cpp + OptimizeMasks.cpp DEPENDS TritonCPUTransformsPassIncGen diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp new file mode 100644 index 000000000000..fc230fdaa855 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -0,0 +1,393 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_OPTIMIZEMASKS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +int64_t getDivisibility(Value val) { + BlockArgument blockArg = dyn_cast(val); + if (!blockArg) + return 1; + + Operation *argOp = blockArg.getOwner()->getParentOp(); + if (auto fn = dyn_cast(argOp)) { + Attribute attr = fn.getArgAttr(blockArg.getArgNumber(), "tt.divisibility"); + if (auto iattr = dyn_cast_or_null(attr)) { + return iattr.getInt(); + } + } + + return 1; +} + +bool isAlwaysDivisible(Value val, int64_t divisor) { + if (auto cst = val.getDefiningOp()) { + auto intAttr = dyn_cast(cst.getValue()); + return intAttr && (intAttr.getInt() % divisor == 0); + } + return getDivisibility(val) % divisor == 0; +} + +bool isAlwaysDivisible(Value val, Value divisor) { + if (auto cst = divisor.getDefiningOp()) { + 
auto intAttr = dyn_cast(cst.getValue()); + if (intAttr) + return isAlwaysDivisible(val, intAttr.getInt()); + } + return false; +} + +// Optimize cdiv pattern using divisibility hints. If value is known to be +// divisible by N then we can transform +// (val + K - 1) / K +// to +// val / K +// if N % K == 0 and val is not negative. Usually, we cannot prove value to be +// non-negative but still can apply transformation for contexts that assume +// positive value (e.g. as an upper bound in a for-loop with non-negative +// lower bound). +struct CdivToDiv : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::DivSIOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + auto addOpDef = lhs.getDefiningOp(); + auto divisorDef = rhs.getDefiningOp(); + if (!addOpDef || !divisorDef) + return failure(); + + arith::ConstantOp addCstDef; + Value addOtherVal; + if (addCstDef = addOpDef.getLhs().getDefiningOp()) + addOtherVal = addOpDef.getRhs(); + else if (addCstDef = addOpDef.getRhs().getDefiningOp()) + addOtherVal = addOpDef.getLhs(); + else + return failure(); + + int64_t divisorCst = cast(divisorDef.getValue()).getInt(); + int64_t addCst = cast(addCstDef.getValue()).getInt(); + if (divisorCst <= addCst) + return failure(); + + if (!isAlwaysDivisible(addOtherVal, divisorCst)) + return failure(); + + Value res = op.getResult(); + Value newRes = + rewriter.create(loc, addOtherVal, divisorDef); + int replaced = 0; + rewriter.replaceUsesWithIf(res, newRes, [&](OpOperand &use) { + if (auto forOp = dyn_cast(use.getOwner())) { + auto lowerDef = + forOp.getLowerBound().getDefiningOp(); + if (lowerDef && use.getOperandNumber() == 1 && + cast(lowerDef.getValue()).getInt() >= 0) { + ++replaced; + return true; + } + } + return false; + }); + if (!replaced) + return failure(); + + return success(); + } +}; + +// This pattern rewrites for-loops used for tiling to optimize out division +// and multiplication using divisibility hints. +// Typical tiled loop looks like: +// for i in range(0, tl.cdiv(size, TILE_SIZE)): +// offs = i * TILE_SIZE +// ... +// If size is known to be divisible by TILE_SIZE then it can be written as: +// for offs in range(0, size, TILE_SIZE): +// ... +// This pattern is used after an attempt to replace cdiv with a regular +// division. Possible input pattern is: +// %c0 = arith.constant 0 : index +// %c1 = arith.constant 1 : index +// %c16 = arith.constant 16 : index +// %init = arith.constant dense<0x00000000> : vector<16xf32> +// %1 = arith.divsi %arg4, %c16 +// %2 = scf.for %arg5 = %c0 to %1 step %c1 iter_args(%arg6 = %init) -> +// (vector<16xf32>) : i32 { +// %3 = arith.muli %arg5, %c16 : i32 +// ... +// } +// where %arg4 is known to be divisible by 16. The resulting code would be: +// %c0 = arith.constant 0 : index +// %c16 = arith.constant 16 : index +// %init = arith.constant dense<0x00000000> : vector<16xf32> +// %2 = scf.for %arg5 = %c0 to %arg4 step %c16 iter_args(%arg6 = %init) -> +// (vector<16xf32>) : i32 { +// ... +// } +// This removes division and simplifies the following analysis to optimize +// masked memory acccess for the tile. 
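+// In Python terms the combined CdivToDiv + ScaleInductionVariable rewrite turns
+//   for i in range((size + 15) // 16):
+//       offs = i * 16
+// into
+//   for offs in range(0, size, 16)
+// which is only valid because `size` carries a divisibility-by-16 hint.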
+struct ScaleInductionVariable : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value iv = op.getInductionVar(); + Value lower = op.getLowerBound(); + Value upper = op.getUpperBound(); + Value step = op.getStep(); + auto lowerDef = lower.getDefiningOp(); + auto upperDef = upper.getDefiningOp(); + if (!lowerDef || !upperDef) + return failure(); + + int64_t lowerVal = cast(lowerDef.getValue()).getInt(); + if (lowerVal < 0) + return failure(); + + // TODO: This is a strong requirement. With more generic value range + // analysis we should be able to not rely on this transformation. + if (!iv.hasOneUse()) + return failure(); + + auto ivUse = dyn_cast(*iv.getUsers().begin()); + if (!ivUse) + return failure(); + + Value scale = ivUse.getLhs() == iv ? ivUse.getRhs() : ivUse.getLhs(); + auto scaleDef = scale.getDefiningOp(); + auto divRhsDef = upperDef.getRhs().getDefiningOp(); + auto divLhs = upperDef.getLhs(); + if (!scaleDef || !divRhsDef) + return failure(); + + int64_t scaleVal = cast(scaleDef.getValue()).getInt(); + int64_t divisorVal = cast(divRhsDef.getValue()).getInt(); + if (scaleVal != divisorVal || !isAlwaysDivisible(divLhs, scaleVal) || + lowerVal % scaleVal != 0) + return failure(); + + // Build new lower bound. + Value newLower = lower; + if (lowerVal != 0) { + rewriter.setInsertionPointAfterValue(lower); + newLower = rewriter.create( + lower.getLoc(), lowerVal * scaleVal, lower.getType()); + } + // New Upper bound. + Value newUpper = divLhs; + // Build new step. + rewriter.setInsertionPoint(op); + auto newStep = rewriter.create(ivUse.getLoc(), step, scale); + + // Modify ForOp. + rewriter.startOpModification(op); + op.setLowerBound(newLower); + op.setUpperBound(newUpper); + op.setStep(newStep); + rewriter.finalizeOpModification(op); + + // Replace iv uses. + rewriter.replaceAllUsesWith(ivUse, iv); + + return success(); + } +}; + +// Build affine expression to express min/max value of the given SSA name. +// symbolTable is used to map SSA names to affine symbols. +AffineExpr buildMinOrMaxExpr(Value val, bool isSigned, bool isMax, + llvm::DenseMap &symbolTable) { + if (auto def = val.getDefiningOp()) { + return buildMinOrMaxExpr(def.getInput(), isSigned, isMax, symbolTable); + } else if (auto def = val.getDefiningOp()) { + auto attr = def.getValueAttr(); + if (auto intAttr = dyn_cast(attr)) + return getAffineConstantExpr(intAttr.getInt(), val.getContext()); + if (auto denseAttr = dyn_cast(attr)) { + auto valueBegin = denseAttr.value_begin(); + auto valueEnd = denseAttr.value_end(); + auto cmpVals = [isSigned](const APInt &lhs, const APInt &rhs) { + return isSigned ? lhs.slt(rhs) : lhs.ult(rhs); + }; + auto valueIt = isMax ? 
std::max_element(valueBegin, valueEnd, cmpVals) + : std::min_element(valueBegin, valueEnd, cmpVals); + return getAffineConstantExpr((*valueIt).getSExtValue(), val.getContext()); + } + } else if (auto def = val.getDefiningOp()) { + return buildMinOrMaxExpr(def.getLhs(), isSigned, isMax, symbolTable) + + buildMinOrMaxExpr(def.getRhs(), isSigned, isMax, symbolTable); + } else if (auto def = val.getDefiningOp()) { + return buildMinOrMaxExpr(def.getLhs(), isSigned, isMax, symbolTable) - + buildMinOrMaxExpr(def.getRhs(), isSigned, !isMax, symbolTable); + } else if (auto blockArg = dyn_cast(val)) { + auto op = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(op)) { + if (val == forOp.getInductionVar()) { + Value lower = forOp.getLowerBound(); + Value upper = forOp.getUpperBound(); + Value step = forOp.getStep(); + + // For min value return lower bound. + if (!isMax) + return buildMinOrMaxExpr(forOp.getLowerBound(), isSigned, isMax, + symbolTable); + + // For max value we use upper bound - 1 in generic case and bound - step + // if both bounds are divisible by the step. + if (isAlwaysDivisible(lower, step) && isAlwaysDivisible(upper, step)) { + return buildMinOrMaxExpr(upper, isSigned, isMax, symbolTable) - + buildMinOrMaxExpr(step, isSigned, false, symbolTable); + } + return buildMinOrMaxExpr(upper, isSigned, isMax, symbolTable) - + getAffineConstantExpr(1, val.getContext()); + } + } + } + + if (symbolTable.count(val)) + return getAffineSymbolExpr(symbolTable.at(val), val.getContext()); + + unsigned pos = symbolTable.size(); + symbolTable.insert(std::make_pair(val, pos)); + return getAffineSymbolExpr(pos, val.getContext()); +} + +// Check if vector mask is all-ones by checking compared values ranges. +// Only simplest cases are covered here, so affine expression is used +// to represent a range for now. +bool isAlwaysAllOnes(Value mask) { + auto maskDef = mask.getDefiningOp(); + if (!maskDef) + return false; + + auto pred = maskDef.getPredicate(); + if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) + return false; + + bool isSigned = + pred == arith::CmpIPredicate::sgt || pred == arith::CmpIPredicate::sge || + pred == arith::CmpIPredicate::sle || pred == arith::CmpIPredicate::slt; + llvm::DenseMap symbolTable; + AffineExpr maxOffs; + AffineExpr minLen; + if (pred == arith::CmpIPredicate::slt || pred == arith::CmpIPredicate::sle || + pred == arith::CmpIPredicate::ult || pred == arith::CmpIPredicate::ule) { + maxOffs = buildMinOrMaxExpr(maskDef.getLhs(), isSigned, true, symbolTable); + minLen = buildMinOrMaxExpr(maskDef.getRhs(), isSigned, false, symbolTable); + } else { + maxOffs = buildMinOrMaxExpr(maskDef.getRhs(), isSigned, true, symbolTable); + minLen = buildMinOrMaxExpr(maskDef.getLhs(), isSigned, false, symbolTable); + } + + // The mask is all-ones if max offset is always less than min length. 
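+  // For the tiled loop above (after ScaleInductionVariable), offs is
+  // iv + [0, 15] with iv stepping over [0, size - 16] by 16 and the mask is
+  // offs < size, so maxOffs = (size - 16) + 15 = size - 1, minLen = size and
+  // diff = -1 < 0: the mask folds to all-ones.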
+ auto diff = maxOffs - minLen; + if (auto diffCst = dyn_cast(diff)) { + int64_t diffVal = diffCst.getValue(); + if (pred == arith::CmpIPredicate::slt || + pred == arith::CmpIPredicate::ult || + pred == arith::CmpIPredicate::sgt || pred == arith::CmpIPredicate::ugt) + return diffVal < 0; + else + return diffVal <= 0; + } + + return false; +} + +struct OptimizeMaskedLoad : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedLoadOp op, + PatternRewriter &rewriter) const override { + if (!isAlwaysAllOnes(op.getMask())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getType(), op.getBase(), + op.getIndices()); + return success(); + } +}; + +struct OptimizeMaskedStore : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MaskedStoreOp op, + PatternRewriter &rewriter) const override { + if (!isAlwaysAllOnes(op.getMask())) + return failure(); + + rewriter.replaceOpWithNewOp(op, op.getValueToStore(), + op.getBase(), op.getIndices()); + return success(); + } +}; + +struct OptimizeMasks + : public triton::cpu::impl::OptimizeMasksBase { + OptimizeMasks() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + // TODO: This pass optimizes out masks applying a set of very strict + // patterns. We should use more generic range and divisibility analysis + // to cover more cases and remove dependency on other transformations. + RewritePatternSet patterns1(context); + patterns1.add(context); + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns1)))) + return signalPassFailure(); + + RewritePatternSet patterns2(context); + patterns2.add(context); + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns2)))) + return signalPassFailure(); + + RewritePatternSet patterns3(context); + patterns3.add(context); + patterns3.add(context); + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns3)))) + return signalPassFailure(); + + // TODO: if masks removal failed for loads/stores in a for-loop, we might + // still optimize it using loop peeling. + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createOptimizeMasks() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 8a2b7f1642ff..6d08801e4775 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -54,6 +54,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_convert_atomic_ops", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertAtomicOps()); }); + m.def("add_optimize_masks", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createOptimizeMasks()); + }); m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm, bool promote_bf16_to_fp32, bool convert_mixed_precision_matmul, bool promote_lib_math_to_fp32) { From 76d9e65402671ad7329f5c9c045d5ec30179565d Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 8 Aug 2024 13:17:12 -0500 Subject: [PATCH 075/165] Fix incorrect casts in mask optimization. 
(#101) Signed-off-by: Ilya Enkovich --- python/test/unit/cpu/test_opt.py | 15 +++++++++++++++ .../cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/python/test/unit/cpu/test_opt.py b/python/test/unit/cpu/test_opt.py index 538f0749868e..32eed6fb7a99 100644 --- a/python/test/unit/cpu/test_opt.py +++ b/python/test/unit/cpu/test_opt.py @@ -65,3 +65,18 @@ def kernel(src, dst, size, TILE_SIZE: tl.constexpr): else: assert masked_loads == 1 assert masked_stores == 1 + + +# Regression test for compilation failure in masks optimization +def test_vec_cdiv(device): + + @triton.jit + def kernel(in_ptr, out_ptr): + offs = tl.arange(0, 16) + x = tl.load(in_ptr + offs) + res = (x + 15) // 16 + tl.store(out_ptr + offs, res) + + arg0 = torch.zeros((16, ), dtype=torch.int32) + arg1 = torch.empty_like(arg0) + kernel[(1, )](arg0, arg1) diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp index fc230fdaa855..271a8b28559e 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -71,6 +71,11 @@ struct CdivToDiv : public OpRewritePattern { LogicalResult matchAndRewrite(arith::DivSIOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); + + // Looking for a scalar op only. + if (isa(op.getType())) + return failure(); + Value lhs = op.getLhs(); Value rhs = op.getRhs(); auto addOpDef = lhs.getDefiningOp(); From 668866fef0b46ae4652305b3129f9c6805c8f924 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 8 Aug 2024 16:39:59 -0400 Subject: [PATCH 076/165] Add conversion for scf.while (#103) & add a test that covers the case where the loop's arguments are tensors, not just scalars --- python/test/unit/language/test_core.py | 24 ++++++------ .../ConvertControlFlowOps.cpp | 38 +++++++++++++++++++ 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 69abbbf0a7c6..81116cc30350 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -5628,27 +5628,29 @@ def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): def test_while(device): @triton.jit - def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): - init_i = tl.load(InitI) + def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ, BLOCKSIZE: tl.constexpr): + range = tl.arange(0, BLOCKSIZE) + init_i = tl.load(InitI + range) curr_i = init_i j = 0 # Check that init_i is not updated by the loop while j < tl.load(Bound): curr_i = curr_i + (j == tl.load(CutOff)) j += 1 - tl.store(OutInitI, init_i) - tl.store(OutI, curr_i) + tl.store(OutInitI + range, init_i) + tl.store(OutI + range, curr_i) tl.store(OutJ, j) - out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) - out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) - init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) - out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + size = 16 + out_i = to_triton(np.zeros((size, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((size, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((size, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((size, ), 0, dtype=np.int32), device=device) bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), 
device=device) - kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) - assert out_init_i[0] == init_i[0] - assert out_i[0] == init_i[0] + 1 + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j, size) + np.testing.assert_equal(to_numpy(out_init_i), to_numpy(init_i)) + np.testing.assert_equal(to_numpy(out_i), to_numpy(init_i + 1)) assert out_j[0] == bound[0] diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp index a0115a897734..491b647103a7 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp @@ -70,6 +70,34 @@ struct ForOpConversion : public OpConversionPattern { } }; +// This is borrowed from SCFWhilePattern in +// lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +class WhileOpConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = getTypeConverter(); + assert(converter); + SmallVector newResultTypes; + if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) + return failure(); + + auto newOp = rewriter.create(op.getLoc(), newResultTypes, + adaptor.getOperands()); + for (auto i : {0u, 1u}) { + auto &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + if (failed(rewriter.convertRegionTypes(&dstRegion, *converter))) + return rewriter.notifyMatchFailure(op, "could not convert body types"); + } + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + // This is borrowed from ConvertFIfOpTypes in // SCF/Transforms/StructuralTypeConversions.cpp // and @@ -132,9 +160,14 @@ struct ConvertControlFlowOps [&](Operation *op) -> std::optional { return typeConverter.isLegal(op); }); + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); { RewritePatternSet patterns(context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } @@ -154,9 +187,14 @@ struct ConvertControlFlowOps [&](Operation *op) -> std::optional { return typeConverter.isLegal(op); }); + convTarget.addDynamicallyLegalOp( + [&](Operation *op) -> std::optional { + return typeConverter.isLegal(op); + }); { RewritePatternSet patterns(context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } From cf99e44401ad6520e64af15b7c5641a56006e651 Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Thu, 8 Aug 2024 15:02:38 -0700 Subject: [PATCH 077/165] [TUTORIAL] Add bf16 matrix vector multiplication tutorial (#90) * Add bf16 matrix vector multiplication tutorial. 
* add gpu benchmark * Minor refactor * Add torch native in the benchmark * use torch.matmul for precision comparison --- .../matrix-vector-multiplication-bf16.py | 191 ++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 python/tutorials/matrix-vector-multiplication-bf16.py diff --git a/python/tutorials/matrix-vector-multiplication-bf16.py b/python/tutorials/matrix-vector-multiplication-bf16.py new file mode 100644 index 000000000000..c37ceb018ff5 --- /dev/null +++ b/python/tutorials/matrix-vector-multiplication-bf16.py @@ -0,0 +1,191 @@ +import torch + +import triton +import triton.language as tl + +BLOCK_SIZE_M = 16 +BLOCK_SIZE_N = 64 +USE_GPU = False +""" +Kernel for computing Y = A @ X, where A is a dense matrix with +M rows and N columns. +- Input X has shape (N,) +- A has shape (M, N) +- Output has shape (M,) +""" + + +@triton.jit +def gemv_kernel( + Y, + A, + X, + M, + N, + stride_am, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + start_m = tl.program_id(0) + rm = start_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = tl.arange(0, BLOCK_SIZE_N) + + A = A + (rm[:, None] * stride_am + rn[None, :]) + X = X + rn + + acc = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.float32) + for n in range(N, 0, -BLOCK_SIZE_N): + a = tl.load(A) + x = tl.load(X) + acc += tl.sum(a * x[None, :], axis=1) + A += BLOCK_SIZE_N + X += BLOCK_SIZE_N + + y = acc.to(tl.bfloat16) + Y = Y + rm + tl.store(Y, y) + + +def gemv( + weight: torch.Tensor, + x: torch.Tensor, + output: torch.Tensor, +): + assert weight.shape[1] == x.shape[0], "Incompatible dimensions" + assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous" + assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}" + + M, N = weight.shape + + # TODO: Currently masked load is not supported yet. + assert M % BLOCK_SIZE_M == 0 and N % BLOCK_SIZE_N == 0, "Masking currently not supported, Matrix dimensions must be multiples of block size" + + if output is None: + # Allocates output. + output = torch.empty((M, ), device=x.device, dtype=x.dtype) + else: + assert output.shape == (M, ) and output.dtype == x.dtype, "Incompatible output" + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), ) + + gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N) + + return output + + +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + +weight = torch.randn((512, 1024), device='cpu', dtype=torch.bfloat16) +x = torch.randn((1024), device='cpu', dtype=torch.bfloat16) +triton_output = gemv(weight, x, None) +compiled_matmul = torch.compile(torch.matmul) +# Note: torch.matmul for bf16 on Arm Linux will trigger error on old torch versions: +# RuntimeError: could not create a primitive descriptor for a matmul primitive +# So we recommend using torch 2.4.0 onwards. 
+torch_output = torch.matmul(weight, x) +#print(f"triton_cpu_output_with_{weight.dtype}_inputs={triton_output}") +#print(f"torch_cpu_output_with_{weight.dtype}_inputs={torch_output}") +rtol = 0 +if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + +LINE_VALS = [ + 'triton-cpu-single', 'triton-cpu', 'torch-cpu-native-single', 'torch-cpu-native', 'torch-cpu-compile-single', + 'torch-cpu-compile' +] +LINE_NAMES = [ + 'TritonCPU 1', 'TritonCPU', 'TorchCPU (native) 1', 'TorchCPU (native)', 'TorchCPU (compile) 1', 'TorchCPU (compile)' +] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-'), ('red', '--'), ('red', '-')] + +if USE_GPU and triton.runtime.driver.get_active_gpus(): + triton.runtime.driver.set_active_to_gpu() + weight = weight.to('cuda') + x = x.to('cuda') + triton_output = gemv(weight, x, None) + torch_output = torch.matmul(weight, x) + #print(f"triton_gpu_output_with_{weight.dtype}_inputs={triton_output}") + #print(f"torch_gpu_output_with_{weight.dtype}_inputs={torch_output}") + rtol = 0 + if torch.allclose(triton_output, torch_output, atol=1e-4, rtol=rtol): + print("✅ TritonGPU and TorchGPU match") + else: + print("❌ TritonGPU and TorchGPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + + LINE_VALS += ['triton-gpu', 'torch-gpu'] + LINE_NAMES += ['TritonGPU', 'TorchGPU'] + LINE_STYLES += [('pink', '-'), ('cyan', '-')] + +default_num_threads = torch.get_num_threads() + +# %% +# Seems like we're good to go! + +# %% +# Benchmark +# --------- + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N"], # Argument names to use as an x-axis for the plot + x_vals=[(512 * i, 4096) for i in range(10, 51, 4)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'gemv-performance-bf16 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N})', + args={}, # Values for function arguments not in `x_names` and `y_name`. 
+ )) +def benchmark(M, N, provider): + import os + + device = 'cpu' if 'cpu' in provider else 'cuda' + weight = torch.randn((M, N), device=device, dtype=torch.bfloat16) + x = torch.randn((N), device=device, dtype=torch.bfloat16) + + if device == 'cpu': + output = torch.empty((M), device=x.device, dtype=x.dtype) + triton.runtime.driver.set_active_to_cpu() + if 'single' in provider: + os.environ['TRITON_CPU_SINGLE_CORE'] = '1' + torch.set_num_threads(1) + else: + os.unsetenv('TRITON_CPU_SINGLE_CORE') + torch.set_num_threads(default_num_threads) + else: + output = None + triton.runtime.driver.set_active_to_gpu() + + quantiles = [0.5, 0.2, 0.8] + if provider == 'torch-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) + elif provider == 'triton-gpu': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + elif 'torch-cpu-native' in provider: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles, + is_cpu=True) + elif 'torch-cpu-compile' in provider: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_matmul(weight, x, out=output), + quantiles=quantiles, is_cpu=True) + elif 'triton-cpu' in provider: + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, is_cpu=True) + + perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) From 2c6453b669f413cf70ddf65fb6e4f83d3329e754 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 9 Aug 2024 10:23:09 -0500 Subject: [PATCH 078/165] Add an option to use sleef instead of libmvec. 
(#104) Signed-off-by: Ilya Enkovich --- python/test/unit/cpu/test_libmvec.py | 6 +- third_party/cpu/backend/compiler.py | 24 ++-- .../cpu/include/TritonCPUToLLVM/Passes.h | 2 + .../cpu/include/TritonCPUToLLVM/Passes.td | 8 +- .../cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp | 112 +++++++++++------- third_party/cpu/triton_cpu.cc | 4 +- 6 files changed, 101 insertions(+), 55 deletions(-) diff --git a/python/test/unit/cpu/test_libmvec.py b/python/test/unit/cpu/test_libmvec.py index 55dc7ec067e4..5873cc7f06a5 100644 --- a/python/test/unit/cpu/test_libmvec.py +++ b/python/test/unit/cpu/test_libmvec.py @@ -53,7 +53,8 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): num_vec_calls = 1 if data_size > 64: num_vec_calls = data_size / 64 - assert meta.asm["asm"].count("_ZGV") == num_vec_calls + prefix = "Sleef" if os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0" else "_ZGV" + assert meta.asm["asm"].count(prefix) == num_vec_calls @pytest.mark.parametrize("dtype_str", float_dtypes) @@ -93,4 +94,5 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): num_vec_calls = 1 if data_size > 64: num_vec_calls = data_size / 64 - assert meta.asm["asm"].count("_ZGV") == num_vec_calls + prefix = "Sleef" if os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0" else "_ZGV" + assert meta.asm["asm"].count(prefix) == num_vec_calls diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 7d1a7e924f3b..807d0bac337e 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -149,8 +149,10 @@ def make_llir(self, src, metadata, options): cpu.passes.ttcpuir.add_memory_op_to_llvmir(pm) cpu.passes.ttcpuir.add_atomic_ops_to_llvmir(pm) cpu.passes.ttcpuir.add_debug_ops_to_llvmir(pm) - if self.cpu_arch == "x86_64" and "avx512f" in self.cpu_features: - cpu.passes.ttcpuir.add_math_to_libmvec(pm) + use_sleef = os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0" + use_vec_math = os.environ.get("TRITON_CPU_USE_LIBMVEC", "1") != "0" + if (use_sleef or use_vec_math) and self.cpu_arch == "x86_64" and "avx512f" in self.cpu_features: + cpu.passes.ttcpuir.add_math_to_libmvec(pm, use_sleef) passes.convert.add_math_to_llvmir(pm) cpu.passes.ttcpuir.add_math_to_libm(pm) cpu.passes.ttcpuir.add_vector_to_llvmir(pm, options.enable_fast_math) @@ -196,14 +198,16 @@ def make_so(src, metadata, options): with tempfile.TemporaryDirectory() as tmpdir: asm_path = os.path.join(tmpdir, "kernel.s") Path(asm_path).write_text(src) - so = _build( - "kernel", - asm_path, - tmpdir, - cpu_driver.library_dirs, - cpu_driver.include_dirs, - ["gcc", "m", "TritonCPURuntime"], - ) + lib_dirs = cpu_driver.library_dirs + libs = ["gcc", "m", "TritonCPURuntime"] + # TRITON_CPU_USE_SLEEF=1 - use system libsleef + # TRITON_CPU_USE_SLEEF=path - use libsleef from the specified path + use_sleef = os.environ.get("TRITON_CPU_USE_SLEEF", "0") + if use_sleef != "0": + if os.path.isdir(use_sleef): + lib_dirs.append(use_sleef) + libs.append("sleef") + so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs) with open(so, "rb") as f: return f.read() diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index 556ac14bb669..fb366d6f82bc 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -27,6 +27,8 @@ std::unique_ptr> createLowerMultiReductionPass(); std::unique_ptr> createAtomicOpsToLLVMPass(); std::unique_ptr> 
createDebugOpsToLLVMPass(); std::unique_ptr> createMathToLibmvecPass(); +std::unique_ptr> +createMathToLibmvecPass(bool use_sleef); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUToLLVM/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index 10942f807248..6b3e0a8bd9d0 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -67,11 +67,17 @@ def DebugOpsToLLVM : Pass<"triton-cpu-debug-ops-to-llvm", "mlir::ModuleOp"> { } def MathToLibmvec : Pass<"triton-cpu-math-to-libmvec", "mlir::ModuleOp"> { - let summary = "Convert vector math operations to vector libm calls."; + let summary = "Convert vector math operations to vector libm or sleef calls."; let description = [{ }]; let constructor = "mlir::triton::cpu::createMathToLibmvecPass()"; + let options = [ + Option<"use_sleef", "use-sleef", + "bool", /*default*/"false", + "Use sleef library for vector math instead of libmvec.">, + ]; + let dependentDialects = ["mlir::vector::VectorDialect", "mlir::triton::cpu::TritonCPUDialect", "mlir::triton::TritonDialect", diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp index 5a035e217fcf..67c666904abd 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp @@ -140,10 +140,11 @@ struct VecOpToLibmvecCall : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; VecOpToLibmvecCall(MLIRContext *context, StringRef fp32FnBaseName, - StringRef fp64FnBaseName) + StringRef fp64FnBaseName, bool use_sleef) : OpRewritePattern(context) { this->fp32FnBaseName = fp32FnBaseName; this->fp64FnBaseName = fp64FnBaseName; + this->use_sleef = use_sleef; } LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { @@ -155,23 +156,12 @@ struct VecOpToLibmvecCall : public OpRewritePattern { if (!elemTy.isF32() && !elemTy.isF64()) return failure(); - auto baseName = elemTy.isF32() ? fp32FnBaseName : fp64FnBaseName; - int64_t vecSize = vecTy.getNumElements() * vecTy.getElementTypeBitWidth(); - std::string isaPrefix; - if (vecSize == 128) { - isaPrefix = "b"; - } else if (vecSize == 256) { - isaPrefix = "d"; - } else if (vecSize == 512) { - isaPrefix = "e"; - } else { + auto fnName = use_sleef + ? getSleefName(elemTy.isF32(), vecTy.getNumElements()) + : getLibmvecName(elemTy.isF32(), vecTy.getNumElements(), + op->getOperands()); + if (fnName.empty()) return failure(); - } - std::string fnName = - "_ZGV" + isaPrefix + "N" + std::to_string(vecTy.getNumElements()); - for (auto operand : op->getOperands()) - fnName += "v"; - fnName += "_" + baseName; auto module = SymbolTable::getNearestSymbolTable(op); auto opFunc = dyn_cast_or_null( @@ -194,49 +184,86 @@ struct VecOpToLibmvecCall : public OpRewritePattern { return success(); } + std::string getLibmvecName(bool isFp32, int64_t numElems, + ValueRange ops) const { + auto baseName = isFp32 ? fp32FnBaseName : fp64FnBaseName; + int64_t vecSize = numElems * (isFp32 ? 
32 : 64); + std::string isaPrefix; + if (vecSize == 128) { + isaPrefix = "b"; + } else if (vecSize == 256) { + isaPrefix = "d"; + } else if (vecSize == 512) { + isaPrefix = "e"; + } else { + return ""; + } + std::string fnName = "_ZGV" + isaPrefix + "N" + std::to_string(numElems); + for (auto operand : ops) + fnName += "v"; + fnName += "_" + baseName; + return fnName; + } + + std::string getSleefName(bool isFp32, int64_t numElems) const { + int64_t vecSize = numElems * (isFp32 ? 32 : 64); + if (vecSize < 128) + return ""; + auto baseName = isFp32 ? fp32FnBaseName : (fp64FnBaseName + "d"); + return "Sleef_" + baseName + std::to_string(numElems) + "_u10"; + } + private: std::string fp32FnBaseName; std::string fp64FnBaseName; + bool use_sleef; }; template void populatePatternsForOp(RewritePatternSet &patterns, StringRef fp32FnName, - StringRef fp64FnName) { + StringRef fp64FnName, bool use_sleef) { patterns.add>(patterns.getContext()); patterns.add>(patterns.getContext()); patterns.add>(patterns.getContext(), fp32FnName, - fp64FnName); + fp64FnName, use_sleef); } struct MathToLibmvecPass : public mlir::triton::cpu::impl::MathToLibmvecBase { - using MathToLibmvecBase::MathToLibmvecBase; + MathToLibmvecPass() = default; + + MathToLibmvecPass(bool use_sleef) { this->use_sleef = use_sleef; } void runOnOperation() override { Operation *op = getOperation(); MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); - populatePatternsForOp(patterns, "acosf", "acos"); - populatePatternsForOp(patterns, "acoshf", "acosh"); - populatePatternsForOp(patterns, "asinf", "asin"); - populatePatternsForOp(patterns, "asinhf", "asinh"); - populatePatternsForOp(patterns, "atanf", "atan"); - populatePatternsForOp(patterns, "atanhf", "atanh"); - populatePatternsForOp(patterns, "cbrtf", "cbrt"); - populatePatternsForOp(patterns, "cosf", "cos"); - populatePatternsForOp(patterns, "coshf", "cosh"); - populatePatternsForOp(patterns, "erff", "erf"); - populatePatternsForOp(patterns, "expf", "exp"); - populatePatternsForOp(patterns, "exp2f", "exp2"); - populatePatternsForOp(patterns, "logf", "log"); - populatePatternsForOp(patterns, "log2f", "log2"); - populatePatternsForOp(patterns, "log10f", "log10"); - populatePatternsForOp(patterns, "log1pf", "log1p"); - populatePatternsForOp(patterns, "sinf", "sin"); - populatePatternsForOp(patterns, "sinhf", "sinh"); - populatePatternsForOp(patterns, "tanf", "tan"); - populatePatternsForOp(patterns, "tanhf", "tanh"); + populatePatternsForOp(patterns, "acosf", "acos", use_sleef); + populatePatternsForOp(patterns, "acoshf", "acosh", + use_sleef); + populatePatternsForOp(patterns, "asinf", "asin", use_sleef); + populatePatternsForOp(patterns, "asinhf", "asinh", + use_sleef); + populatePatternsForOp(patterns, "atanf", "atan", use_sleef); + populatePatternsForOp(patterns, "atanhf", "atanh", + use_sleef); + populatePatternsForOp(patterns, "cbrtf", "cbrt", use_sleef); + populatePatternsForOp(patterns, "cosf", "cos", use_sleef); + populatePatternsForOp(patterns, "coshf", "cosh", use_sleef); + populatePatternsForOp(patterns, "erff", "erf", use_sleef); + populatePatternsForOp(patterns, "expf", "exp", use_sleef); + populatePatternsForOp(patterns, "exp2f", "exp2", use_sleef); + populatePatternsForOp(patterns, "logf", "log", use_sleef); + populatePatternsForOp(patterns, "log2f", "log2", use_sleef); + populatePatternsForOp(patterns, "log10f", "log10", + use_sleef); + populatePatternsForOp(patterns, "log1pf", "log1p", + use_sleef); + populatePatternsForOp(patterns, "sinf", 
"sin", use_sleef); + populatePatternsForOp(patterns, "sinhf", "sinh", use_sleef); + populatePatternsForOp(patterns, "tanf", "tan", use_sleef); + populatePatternsForOp(patterns, "tanhf", "tanh", use_sleef); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); @@ -253,6 +280,11 @@ std::unique_ptr> createMathToLibmvecPass() { return std::make_unique(); } +std::unique_ptr> +createMathToLibmvecPass(bool use_sleef) { + return std::make_unique(use_sleef); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 6d08801e4775..99975bb2f5d2 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -114,8 +114,8 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); }); - m.def("add_math_to_libmvec", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createMathToLibmvecPass()); + m.def("add_math_to_libmvec", [](mlir::PassManager &pm, bool use_sleef) { + pm.addPass(mlir::triton::cpu::createMathToLibmvecPass(use_sleef)); }); m.def("add_math_to_libm", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertMathToLibmPass()); From f727802f5482b088834127c329dfc18d4df2bb8c Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 9 Aug 2024 12:53:57 -0500 Subject: [PATCH 079/165] Enable fast math by default. (#108) Signed-off-by: Ilya Enkovich --- third_party/cpu/backend/compiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 807d0bac337e..d12698564e03 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -29,7 +29,7 @@ class CPUOptions: allow_fp8e4b15: bool = True enable_fp_fusion: bool = True max_num_imprecise_acc_default: int = 0 - enable_fast_math: bool = False + enable_fast_math: bool = True # TODO: We may introduce CPU-specific options like # of cores. 
@@ -58,7 +58,7 @@ def __init__(self, target: tuple) -> None: def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} if not "enable_fast_math" in args: - args["enable_fast_math"] = os.getenv("TRITON_CPU_FAST_MATH", "0") == "1" + args["enable_fast_math"] = os.getenv("TRITON_CPU_FAST_MATH", "1") != "0" return CPUOptions(**args) def pack_metadata(self, metadata): From b6009ebcab89fc503f3cdba84c6a4c6484fbceca Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 9 Aug 2024 14:15:10 -0400 Subject: [PATCH 080/165] Add more libdevice lowerings (#97) --- python/src/ir.cc | 8 ++++++ python/test/unit/cpu/test_libdevice.py | 8 ++++-- python/triton/language/extra/cpu/libdevice.py | 25 +++++++++++++++++++ .../ConvertElementwiseOps.cpp | 2 ++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index 61e1cdee92d6..da11b0f39087 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1617,6 +1617,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_expm1", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_cos", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); @@ -1669,6 +1673,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_log1p", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_log2", [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); diff --git a/python/test/unit/cpu/test_libdevice.py b/python/test/unit/cpu/test_libdevice.py index 22cb6286f3b8..6db1478d95f3 100644 --- a/python/test/unit/cpu/test_libdevice.py +++ b/python/test/unit/cpu/test_libdevice.py @@ -21,14 +21,18 @@ def is_cpu(): @pytest.mark.parametrize("dtype_str", float_dtypes) @pytest.mark.parametrize("math_fn", [ - "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "log", "log2", - "log10", "sin", "sinh", "tan", "tanh" + "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "expm1", "floor", + "log", "log1p", "log2", "log10", "rsqrt", "sin", "sinh", "sqrt", "tan", "tanh" ]) @pytest.mark.parametrize("size", [1, 4, 16, 64]) def test_libdevice(dtype_str, math_fn, size, device): if not is_cpu(): pytest.skip("This test is CPU-specific") + if dtype_str == "bfloat16": + if math_fn == "floor" or math_fn == "rsqrt": + pytest.skip("libgcc < 13 does not define __truncsfbf2, which this op needs") + @triton.jit def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): idxs = tl.arange(0, BLOCK_SIZE) diff --git a/python/triton/language/extra/cpu/libdevice.py b/python/triton/language/extra/cpu/libdevice.py index bc1926f4b893..9dbb4d682d42 100644 --- a/python/triton/language/extra/cpu/libdevice.py +++ b/python/triton/language/extra/cpu/libdevice.py @@ -61,6 +61,16 @@ def exp2(arg0, _builder=None): return core.tensor(_builder.create_exp2(arg0.handle), arg0.type) +@core.extern +def expm1(arg0, _builder=None): + return core.tensor(_builder.create_expm1(arg0.handle), arg0.type) + + +@core.extern +def floor(arg0, _builder=None): + return core.tensor(_builder.create_floor(arg0.handle), arg0.type) + + @core.extern def log(arg0, _builder=None): return core.tensor(_builder.create_log(arg0.handle), arg0.type) @@ -76,11 +86,26 @@ def log10(arg0, _builder=None): 
return core.tensor(_builder.create_log10(arg0.handle), arg0.type) +@core.extern +def log1p(arg0, _builder=None): + return core.tensor(_builder.create_log1p(arg0.handle), arg0.type) + + @core.extern def sin(arg0, _builder=None): return core.tensor(_builder.create_sin(arg0.handle), arg0.type) +@core.extern +def rsqrt(arg0, _builder=None): + return core.tensor(_builder.create_rsqrt(arg0.handle), arg0.type) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.tensor(_builder.create_sqrt(arg0.handle), arg0.type) + + @core.extern def sinh(arg0, _builder=None): return core.tensor(_builder.create_sinh(arg0.handle), arg0.type) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 8c2377babbac..7b3836898c2b 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -220,9 +220,11 @@ struct ConvertElementwiseOps patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); From 97341f03d22a66163b3de648d8e14e5d155cdd9e Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 9 Aug 2024 15:26:44 -0500 Subject: [PATCH 081/165] Enable rsqrt and floor for BF16. (#109) Signed-off-by: Ilya Enkovich --- python/test/unit/cpu/test_libdevice.py | 4 ---- .../ConvertUnsupportedOps.cpp | 1 + .../DecomposeFpConversions.cpp | 23 ++++++++++++++++++- 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/python/test/unit/cpu/test_libdevice.py b/python/test/unit/cpu/test_libdevice.py index 6db1478d95f3..5a37ec9af21d 100644 --- a/python/test/unit/cpu/test_libdevice.py +++ b/python/test/unit/cpu/test_libdevice.py @@ -29,10 +29,6 @@ def test_libdevice(dtype_str, math_fn, size, device): if not is_cpu(): pytest.skip("This test is CPU-specific") - if dtype_str == "bfloat16": - if math_fn == "floor" or math_fn == "rsqrt": - pytest.skip("libgcc < 13 does not define __truncsfbf2, which this op needs") - @triton.jit def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): idxs = tl.arange(0, BLOCK_SIZE) diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index ec51257caf67..af80b6d6c0f9 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -381,6 +381,7 @@ struct ConvertUnsupportedOps patterns.add>(context); patterns.add>(context); patterns.add>(context); + patterns.add>(context); patterns.add>(context); patterns.add>(context); patterns.add>(context); diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp index df1d0e34cd87..be5585347ba2 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -43,6 +43,25 @@ struct Fp32ToBf16Conversion : public OpRewritePattern { } }; +struct Bf16ToFp32Conversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + 
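  // Why the rewrite below is exact: bf16 keeps the sign, exponent, and top 7
  // mantissa bits of an fp32 value, i.e. it is the high 16 bits of the fp32
  // encoding. Widening bf16 -> fp32 therefore needs only integer moves: the
  // pattern bitcasts the bf16 lanes to i16, zero-extends to i32, shifts left
  // by 16 to restore the high half, and bitcasts the result back to f32.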
+ LogicalResult matchAndRewrite(arith::ExtFOp op, + PatternRewriter &rewriter) const override { + Value src = op.getIn(); + if (!isFp32(op.getType()) || !isBf16(src.getType())) + return failure(); + + Location loc = op.getLoc(); + Value i16Src = op_bitcast(toInt16(src.getType()), src); + Value i32Src = op_zext(toInt32(src.getType()), i16Src); + Value i32Res = op_shl(i32Src, cst_like(i32Src, 16)); + Value res = op_bitcast(op.getType(), i32Res); + rewriter.replaceOp(op, res); + return success(); + } +}; + typedef std::function FpToFpConvFn; // Convert FP8 to FP16/FP32. @@ -501,8 +520,10 @@ struct DecomposeFpConversions ModuleOp mod = getOperation(); RewritePatternSet patterns(context); - if (decomposeBf16Conversions) + if (decomposeBf16Conversions) { patterns.add(context); + patterns.add(context); + } if (decomposeFp8Conversions) { patterns.add(context); patterns.add(context); From 9e94a21451750ce31f0b7c3af3611e048babbc66 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 9 Aug 2024 17:45:47 -0500 Subject: [PATCH 082/165] Remove specific dwarf version from -g option. (#110) Signed-off-by: Ilya Enkovich --- python/triton/runtime/build.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 72a66ddec32f..c44659be31bc 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -58,7 +58,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): if src.endswith(".cpp") or src.endswith(".cc"): cc_cmd += ["-std=c++17", "-fopenmp"] if src.endswith(".s"): - cc_cmd += ["-gdwarf-5"] + # This is required to properly parse .file directives + cc_cmd += ["-g"] if system == "Linux" and machine in ("aarch64", "arm64"): # On Arm backend, some CPU (neoverse-v2) needs to be specified through -mcpu cc_cmd += ["-mcpu=native"] From 29b9fdb3f25eab879a0d2b8c8cf8a73a30805035 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Tue, 13 Aug 2024 11:42:11 +0000 Subject: [PATCH 083/165] Enable `min_dot_size` Add min_dot_size to compiler.py to avoid new assertion. 
Signed-off-by: Dmitrii Makarenko --- third_party/cpu/backend/compiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index d12698564e03..15bd3a882197 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -12,6 +12,10 @@ from triton.runtime.build import _build import triton.backends.cpu.driver as cpu_driver +def min_dot_size(target: GPUTarget): + # Other architectures will only support 16,16,16 + return lambda lhsType, rhsType: (4, 4, 4) + @dataclass(frozen=True) class CPUOptions: @@ -65,7 +69,7 @@ def pack_metadata(self, metadata): return metadata def get_codegen_implementation(self): - codegen_fns = dict() + codegen_fns = {"min_dot_size": min_dot_size(self.target)} return codegen_fns def load_dialects(self, ctx): From 8ff792f1b9d2c3f598bab1f98fbf217045ac38b5 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Tue, 13 Aug 2024 11:44:16 +0000 Subject: [PATCH 084/165] [Formatting] Apply formating This commit applies formating to rebased code Signed-off-by: Dmitrii Makarenko --- python/test/unit/language/test_core.py | 3 ++- third_party/cpu/backend/compiler.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 81116cc30350..5c432fd432a3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6495,7 +6495,8 @@ def matmul_kernel( # @pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("M, N, K", [(128, 256, 256)]) -@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), (64, 64, 64)]) +@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 128), + (64, 64, 64)] if not is_cpu() else [(32, 32, 128), (32, 32, 32)]) @pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv', 'float8e4b15']) @pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) def test_dot_max_num_imprecise_acc(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, in_type_str, low_precision_acc, device): diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 15bd3a882197..d6f1a6d82a83 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -12,6 +12,7 @@ from triton.runtime.build import _build import triton.backends.cpu.driver as cpu_driver + def min_dot_size(target: GPUTarget): # Other architectures will only support 16,16,16 return lambda lhsType, rhsType: (4, 4, 4) From 8e0d331a10fec5e6b1da284cd47b23a5de538227 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 13 Aug 2024 15:11:02 -0500 Subject: [PATCH 085/165] Remove is_cpu arg from do_bench. (#113) Signed-off-by: Ilya Enkovich --- python/triton/testing.py | 69 +++++++------------ python/tutorials/01-vector-add.py | 34 ++++----- python/tutorials/02-fused-softmax-cpu.py | 18 +++-- .../tutorials/03-matrix-multiplication-cpu.py | 15 ++-- .../matrix-vector-multiplication-bf16.py | 13 ++-- .../tutorials/matrix-vector-multiplication.py | 25 ++++--- 6 files changed, 82 insertions(+), 92 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index 54e7ee7c14ed..22adade12a35 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -11,26 +11,24 @@ from . 
import runtime -class Event: +class CPUDeviceInterface: - def __init__(self, is_cpu): - self.time = 0 - self.is_cpu = is_cpu - if not is_cpu: - import torch - self.cuda_event = torch.cuda.Event(enable_timing=True) + class Event: - def elapsed_time(self, end_event) -> float: - if self.is_cpu: + def __init__(self, enable_timing=True): + self.time = 0 + + def elapsed_time(self, end_event) -> float: return (end_event.time - self.time) * 1000 - else: - return self.cuda_event.elapsed_time(end_event.cuda_event) - def record(self): - if self.is_cpu: + def record(self): self.time = time.perf_counter() - else: - self.cuda_event.record() + + def __init__(self): + pass + + def synchronize(self): + pass def nvsmi(attrs): @@ -143,7 +141,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod return _summarize_statistics(ret, quantiles, return_mode) -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", is_cpu=False): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -163,46 +161,28 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m """ assert return_mode in ["min", "max", "mean", "median", "all"] - if not is_cpu: - di = runtime.driver.active.get_device_interface() + di = runtime.driver.active.get_device_interface() fn() - if not is_cpu: - di.synchronize() - - if not is_cpu: - cache = runtime.driver.active.get_empty_cache_for_benchmark() - if is_cpu: - # Currently, a typical L3 cache size for high-end server CPUs are ~400MB. - cache_size = 512 * 1024 * 1024 - cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cpu') - - if not is_cpu: - # Estimate the runtime of the function - start_event = di.Event(enable_timing=True) - end_event = di.Event(enable_timing=True) - else: - start_event = Event(is_cpu) - end_event = Event(is_cpu) + di.synchronize() + + cache = runtime.driver.active.get_empty_cache_for_benchmark() + start_event = di.Event(enable_timing=True) + end_event = di.Event(enable_timing=True) start_event.record() for _ in range(5): runtime.driver.active.clear_cache(cache) fn() end_event.record() - if not is_cpu: - di.synchronize() + di.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 # compute number of warmup and repeat n_warmup = max(1, int(warmup / estimate_ms)) n_repeat = max(1, int(rep / estimate_ms)) - if not is_cpu: - start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] - end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] - else: - start_event = [Event(is_cpu) for i in range(n_repeat)] - end_event = [Event(is_cpu) for i in range(n_repeat)] + start_event = [di.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [di.Event(enable_timing=True) for i in range(n_repeat)] # Warm-up for _ in range(n_warmup): fn() @@ -222,7 +202,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m end_event[i].record() # Record clocks di.synchronize() - times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] + + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) return _summarize_statistics(times, quantiles, return_mode) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index a6bfb8371f85..222ad5359c37 100644 --- 
a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -61,9 +61,10 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. # and (2) enqueue the above kernel with appropriate grid/block sizes: -def add(x: torch.Tensor, y: torch.Tensor, is_cpu): - # We need to preallocate the output. - output = torch.empty_like(x) +def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device): + if output is None: + # We need to preallocate the output. + output = torch.empty_like(x) assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE n_elements = output.numel() # The SPMD launch grid denotes the number of kernel instances that run in parallel. @@ -74,7 +75,7 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): # - Each torch.tensor object is implicitly converted into a pointer to its first element. # - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. # - Don't forget to pass meta-parameters as keywords arguments. - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if is_cpu else GPU_BLOCK_SIZE) + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE if device == 'cpu' else GPU_BLOCK_SIZE) # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still # running asynchronously at this point. return output @@ -88,7 +89,7 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): x = torch.rand(size, device=DEVICE) y = torch.rand(size, device=DEVICE) output_torch_cpu = torch.add(x, y) -output_triton_cpu = add(x, y, None, is_cpu=True) +output_triton_cpu = add(x, y, None, device='cpu') print(output_torch_cpu) print(output_triton_cpu) print(f'The maximum difference between torch-cpu and triton-cpu is ' @@ -102,12 +103,12 @@ def add(x: torch.Tensor, y: torch.Tensor, is_cpu): triton.runtime.driver.set_active_to_gpu() x = x.to(DEVICE) y = y.to(DEVICE) - output_torch = x + y - output_triton = add(x, y) - print(output_torch) - print(output_triton) + output_torch_gpu = x + y + output_triton_gpu = add(x, y, None, device=DEVICE) + print(output_torch_gpu) + print(output_triton_gpu) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}') + f'{torch.max(torch.abs(output_torch_gpu - output_triton_gpu))}') LINE_VALS += ['triton-gpu', 'torch-gpu'] LINE_NAMES += ['TritonGPU', 'TorchGPU'] @@ -145,31 +146,30 @@ def benchmark(size, provider): y = torch.rand(size, device=DEVICE, dtype=torch.float32) if DEVICE == 'cpu': - is_cpu = True triton.runtime.driver.set_active_to_cpu() else: - is_cpu = False triton.runtime.driver.set_active_to_gpu() quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, device_type=DEVICE) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles, + device_type=DEVICE) elif provider == 'torch-cpu': # Note that we preallocate the output buffer here to only measure the kernel performance # without a large chunk of memory allocation. 
output = torch.empty_like(x) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles, - is_cpu=is_cpu) + device_type=DEVICE) elif provider == 'triton-cpu-single': output = torch.empty_like(x) ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, - is_cpu=is_cpu) + device_type=DEVICE) elif provider == 'triton-cpu': output = torch.empty_like(x) ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, - is_cpu=is_cpu) + device_type=DEVICE) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/02-fused-softmax-cpu.py b/python/tutorials/02-fused-softmax-cpu.py index 1fce4c78345f..355277d4bb9c 100644 --- a/python/tutorials/02-fused-softmax-cpu.py +++ b/python/tutorials/02-fused-softmax-cpu.py @@ -199,7 +199,6 @@ def benchmark(M, N, provider): x = torch.randn(M, N, device=device, dtype=torch.float32) if device == 'cpu': - is_cpu = True y = torch.empty_like(x) triton.runtime.driver.set_active_to_cpu() if 'single' in provider: @@ -207,30 +206,29 @@ def benchmark(M, N, provider): else: os.unsetenv('TRITON_CPU_SINGLE_CORE') else: - is_cpu = False y = None triton.runtime.driver.set_active_to_gpu() quantiles = [0.5, 0.2, 0.8] if provider == 'torch-cpu-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, - is_cpu=is_cpu) + device_type=device) if provider == 'torch-cpu-jit': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, device_type=device) if provider == 'torch-cpu-compile': compiled = torch.compile(naive_softmax) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, device_type=device) if provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, device_type=device) if provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, device_type=device) if provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, device_type=device) if provider == 'torch-gpu-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, - is_cpu=is_cpu) + device_type=device) if provider == 'torch-gpu-jit': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, is_cpu=is_cpu) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, device_type=device) gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index 937cbc652ba7..91378c5db836 100644 --- 
a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -378,19 +378,22 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles, + device_type=device) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles, + device_type=device) elif provider == 'torch-cpu-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, - is_cpu=True) + device_type=device) elif provider == 'torch-cpu-compile': compiled = torch.compile(torch.matmul) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles, + device_type=device) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, device_type=device) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, device_type=device) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/matrix-vector-multiplication-bf16.py b/python/tutorials/matrix-vector-multiplication-bf16.py index c37ceb018ff5..7993e4090b20 100644 --- a/python/tutorials/matrix-vector-multiplication-bf16.py +++ b/python/tutorials/matrix-vector-multiplication-bf16.py @@ -169,17 +169,20 @@ def benchmark(M, N, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles, + device_type=device) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, + device_type=device) elif 'torch-cpu-native' in provider: ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles, - is_cpu=True) + device_type=device) elif 'torch-cpu-compile' in provider: ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_matmul(weight, x, out=output), - quantiles=quantiles, is_cpu=True) + quantiles=quantiles, device_type=device) elif 'triton-cpu' in provider: - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, + device_type=device) perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/matrix-vector-multiplication.py b/python/tutorials/matrix-vector-multiplication.py index 288daab90178..5d44ddf9c2c2 100644 --- 
a/python/tutorials/matrix-vector-multiplication.py +++ b/python/tutorials/matrix-vector-multiplication.py @@ -173,33 +173,38 @@ def benchmark(M, N, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles, + device_type=device) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, + device_type=device) elif provider == 'torch-cpu-native' or provider == 'torch-cpu-2d-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles, - is_cpu=True) + device_type=device) elif provider == 'torch-cpu-compile' or provider == 'torch-cpu-2d-compile': compiled = torch.compile(torch.matmul) ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(weight, x, out=output), quantiles=quantiles, - is_cpu=True) + device_type=device) elif provider == 'torch-cpu-transpose-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, weight, out=output), quantiles=quantiles, - is_cpu=True) + device_type=device) elif provider == 'torch-cpu-transpose-compile': compiled = torch.compile(torch.matmul) ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x, weight, out=output), quantiles=quantiles, - is_cpu=True) + device_type=device) elif provider == 'torch-cpu-linear': weight = torch.nn.Linear(N, M, bias=False, device=weight.device, dtype=weight.dtype) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles, device_type=device) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, + device_type=device) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, + device_type=device) elif provider == 'triton-cpu-linear': # torch.nn.Linear.forward does not take preallocated output buffer, so we also do no provide output buffer for fair comparison - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles, is_cpu=True) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles, + device_type=device) perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) From 0e5fb8fb1ef88d6a66d6160b40e4f4850943e5fd Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 13 Aug 2024 16:37:52 -0500 Subject: [PATCH 086/165] Enable few more tutorials for CPU (#114) * Enable extern functions tutorial for CPU. Signed-off-by: Ilya Enkovich * Enable low memory dropout tutorial for CPU. Signed-off-by: Ilya Enkovich * Support layer norm tutorial for CPU. 
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- python/tutorials/04-low-memory-dropout.py | 1 + python/tutorials/05-layer-norm.py | 13 ++++++++----- python/tutorials/07-extern-functions.py | 3 ++- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index 3dd84da47e6f..391330fff094 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -72,6 +72,7 @@ def dropout(x, x_keep, p): return output +device = triton.runtime.driver.active.get_current_target().backend # Input tensor x = torch.randn(size=(10, ), device=DEVICE) # Dropout mask diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 5be07a9ea7d2..76693a85b7f4 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -290,6 +290,9 @@ def backward(ctx, dy): layer_norm = LayerNorm.apply +device = triton.runtime.driver.active.get_current_target().backend +# Torch doesn't support operations in float16 on CPU so use float32 instead +dtype = torch.float32 if device == 'cpu' else torch.flaot16 def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): @@ -328,7 +331,7 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', plot_name='layer-norm-backward', - args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, + args={'M': 4096, 'dtype': dtype, 'mode': 'backward'}, )) def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device=DEVICE): # create data @@ -356,18 +359,18 @@ def y_fwd(): # forward pass if mode == 'forward': gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) - ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500, device_type=device) # backward pass if mode == 'backward': y = y_fwd() gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: F811, E704 ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, - grad_to_none=[x], rep=500) + grad_to_none=[x], rep=500, device_type=device) return gbps(ms), gbps(max_ms), gbps(min_ms) -test_layer_norm(1151, 8192, torch.float16) -bench_layer_norm.run(save_path='.', print_data=True) +test_layer_norm(1151, 8192, dtype, device=device) +bench_layer_norm.run(save_path='.', print_data=True, device=device) # %% # References diff --git a/python/tutorials/07-extern-functions.py b/python/tutorials/07-extern-functions.py index 800563701ff0..03e44a72b3b7 100644 --- a/python/tutorials/07-extern-functions.py +++ b/python/tutorials/07-extern-functions.py @@ -49,12 +49,13 @@ def asin_kernel( # ----------------------------------------- # We can use the default libdevice library path encoded in `triton/language/math.py` +device = triton.runtime.driver.active.get_current_target().backend + torch.manual_seed(0) size = 98432 x = torch.rand(size, device=DEVICE) output_triton = torch.zeros(size, device=DEVICE) output_torch = torch.asin(x) -assert x.is_cuda and output_triton.is_cuda n_elements = output_torch.numel() grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) From 869b5ab4780434009f65561bcd7f17356d05a271 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 13 Aug 2024 16:44:00 -0500 
Subject: [PATCH 087/165] Pass device type to do_bench in autotuner. (#115) Signed-off-by: Ilya Enkovich --- .github/workflows/build-test.yml | 3 ++- python/triton/runtime/autotuner.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 04417f6c5485..46865e1b556e 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -85,7 +85,8 @@ jobs: python/test/unit/language/test_conversions.py \ python/test/unit/cpu/test_libdevice.py \ python/test/unit/cpu/test_libmvec.py \ - python/test/unit/cpu/test_opt.py + python/test/unit/cpu/test_opt.py \ + python/test/unit/runtime/test_autotuner.py - name: Run lit tests run: | diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 339b79529537..23ff224998cd 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -164,7 +164,8 @@ def kernel_call(): self.post_hook(full_nargs, exception=None) try: - return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + device = driver.active.get_current_target().backend + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), device_type=device) except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: if verbose: print(f"Autotuning failed with {e}") From 96fc92edc93734b4b5cc7fa46876d93205f7ca1a Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 13 Aug 2024 22:09:34 -0500 Subject: [PATCH 088/165] Fix indices extraction from block pointer. (#116) Signed-off-by: Ilya Enkovich --- .../test/unit/language/test_block_pointer.py | 18 ++++++++++++++++++ .../cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 7 ++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index 1f2f5b5e995f..fb2002101bb3 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -65,6 +65,24 @@ def test_block_copy(dtypes_str, n, padding_option, boundary_check, device): assert torch.all(torch.isnan(b[n // 2:n])) +def test_block_copy2d(device): + + @triton.jit + def kernel(in_ptr, out_ptr, M: tl.constexpr, N: tl.constexpr, BLOCK_M: tl.constexpr): + block_offset = tl.program_id(0) * BLOCK_M + in_block_ptr = tl.make_block_ptr(base=in_ptr, shape=(M, N), strides=(N, 1), offsets=(block_offset, 0), + block_shape=(BLOCK_M, N), order=(1, 0)) + out_block_ptr = tl.make_block_ptr(base=out_ptr, shape=(M, N), strides=(N, 1), offsets=(block_offset, 0), + block_shape=(BLOCK_M, N), order=(1, 0)) + x = tl.load(in_block_ptr) + tl.store(out_block_ptr, x) + + inp = torch.randn((256, 16), device=device, dtype=torch.float32) + res = torch.empty_like(inp) + kernel[(16, )](inp, res, M=16, N=16, BLOCK_M=16) + assert (inp == res).all() + + @triton.jit def matmul_no_scf_with_advance_kernel( # a_ptr, b_ptr, c_ptr, # diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp index 68d7231039c5..a3fbf20a713e 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -98,11 +98,8 @@ struct ExtractIndicesOpConversion SmallVector indices; for (int64_t i = 0; i < rank; i++) { - Value offs = rewriter.create( - loc, i64Ty, tensorPtrStruct, SmallVector{1, i}); - Value stride = rewriter.create( - loc, i64Ty, tensorPtrStruct, SmallVector{3, i}); - 
indices.push_back(rewriter.create(loc, offs, stride)); + indices.push_back(rewriter.create( + loc, i64Ty, tensorPtrStruct, SmallVector{1, i})); } rewriter.replaceOp(op, indices); From d1e748cfc880ec9728ea83bb6160e8cc13b73b02 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Tue, 13 Aug 2024 22:04:58 -0700 Subject: [PATCH 089/165] [cpu] Rework device_print with triton_cpu.print and 1D vector printing (#99) * [cpu] Rework device_print with 1D vector printing * Update minor comments * Apply suggestions from code review Co-authored-by: Jez Ng * A few fixes upon the previous suggestions * Refactoring + update comments --------- Co-authored-by: Jez Ng --- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 20 ++ .../Dialect/TritonCPU/IR/TritonCPUTypes.td | 3 + lib/Dialect/TritonCPU/IR/CMakeLists.txt | 1 + lib/Dialect/TritonCPU/IR/Dialect.cpp | 3 - lib/Dialect/TritonCPU/IR/Ops.cpp | 18 ++ third_party/cpu/backend/compiler.py | 1 + .../cpu/include/TritonToTritonCPU/Passes.h | 1 + .../cpu/include/TritonToTritonCPU/Passes.td | 13 ++ .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 196 +++++++++++------- .../cpu/lib/TritonCPUToLLVM/TypeConverter.cpp | 15 +- .../cpu/lib/TritonCPUToLLVM/TypeConverter.h | 1 + .../lib/TritonCPUTransforms/OptimizeMasks.cpp | 4 +- .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 1 + .../lib/TritonToTritonCPU/ConvertDebugOps.cpp | 100 +++++++++ third_party/cpu/runtime/cpu_runtime.cpp | 192 ++++++++++++++++- third_party/cpu/triton_cpu.cc | 3 + 16 files changed, 487 insertions(+), 85 deletions(-) create mode 100644 lib/Dialect/TritonCPU/IR/Ops.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 712826d02f91..6bcca9ec0d5b 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -76,4 +76,24 @@ def TTC_PtrToMemRefOp : TTC_Op<"ptr_to_memref", [NoMemoryEffect]> { let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; } +def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { + let summary = "Print at most a single scalar or vector (converted from tensor) on each line"; + + let description = [{ + For converting tensor types to vector types. + It only takes a single scalar or vector (tensor) element. + }]; + + let arguments = (ins StrAttr:$prefix, BoolAttr:$hex, + Variadic>:$val); + + let assemblyFormat = [{ + $prefix attr-dict (`:` $val^ `:` type($val))? 
+ }]; + + let hasVerifier = 1; +} + #endif diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td index ea31f877dab3..4bd64213db4b 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td @@ -1,6 +1,7 @@ #ifndef TRITONCPU_TYPES #define TRITONCPU_TYPES +include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/TritonCPU/IR/TritonCPUDialect.td" include "mlir/IR/AttrTypeBase.td" @@ -23,4 +24,6 @@ def TTC_TokenType : TTC_TypeDef<"Token", "token"> { let skipDefaultBuilders = 1; } +def TTC_Vector : VectorOf<[TT_Float, TT_Int]>; + #endif diff --git a/lib/Dialect/TritonCPU/IR/CMakeLists.txt b/lib/Dialect/TritonCPU/IR/CMakeLists.txt index 67bf1bb1b9d4..c0b6f0f7be24 100644 --- a/lib/Dialect/TritonCPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonCPU/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_triton_library(TritonCPUIR Dialect.cpp + Ops.cpp Types.cpp DEPENDS diff --git a/lib/Dialect/TritonCPU/IR/Dialect.cpp b/lib/Dialect/TritonCPU/IR/Dialect.cpp index acd31c07290f..41a4c62bda45 100644 --- a/lib/Dialect/TritonCPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonCPU/IR/Dialect.cpp @@ -67,9 +67,6 @@ void TritonCPUDialect::initialize() { >(); } -#define GET_OP_CLASSES -#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" - // verify TritonCPU ops LogicalResult TritonCPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { diff --git a/lib/Dialect/TritonCPU/IR/Ops.cpp b/lib/Dialect/TritonCPU/IR/Ops.cpp new file mode 100644 index 000000000000..d626ce3902a9 --- /dev/null +++ b/lib/Dialect/TritonCPU/IR/Ops.cpp @@ -0,0 +1,18 @@ +#include "mlir/IR/Builders.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#define GET_OP_CLASSES +#include "triton/Dialect/TritonCPU/IR/Ops.cpp.inc" + +// enum attribute definitions +#include "triton/Dialect/TritonCPU/IR/OpsEnums.cpp.inc" + +namespace mlir::triton::cpu { + +LogicalResult PrintOp::verify() { + if (getOperands().size() > 1) + return emitOpError("expects at most one operand"); + return success(); +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index d6f1a6d82a83..cf73ce52b0b2 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -106,6 +106,7 @@ def make_ttcir(mod, metadata, opt): cpu.passes.ttcpuir.add_convert_scan_op(pm) cpu.passes.ttcpuir.add_convert_cf_ops(pm) cpu.passes.ttcpuir.add_convert_atomic_ops(pm) + cpu.passes.ttcpuir.add_convert_debug_ops(pm) passes.common.add_cse(pm) passes.common.add_symbol_dce(pm) passes.common.add_canonicalizer(pm) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 303b99ce3c43..699518361e01 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -28,6 +28,7 @@ std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); std::unique_ptr> createConvertScanOp(); std::unique_ptr> createConvertAtomicOps(); +std::unique_ptr> createConvertDebugOps(); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonToTritonCPU/Passes.h.inc" diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index dfac926a9f5b..612ce135cc65 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ 
-142,4 +142,17 @@ def ConvertAtomicOps : Pass<"triton-cpu-convert-atomic-ops", "mlir::ModuleOp"> { "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertDebugOps : Pass<"triton-cpu-convert-debug-ops", "mlir::ModuleOp"> { + let summary = "Convert Triton debug operations."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertDebugOps()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index 2bad397c9b77..c60da23b765a 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -4,6 +4,7 @@ #include "cpu/include/TritonCPUToLLVM/Passes.h" #include "mlir/Dialect/GPU/IR/GPUOps.h.inc" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -31,54 +32,6 @@ class TritonLLVMConversionTarget : public ConversionTarget { } }; -// The code for the print is similar to the GPU's TargetInfo.cpp. -LLVM::LLVMFuncOp getPrintfDeclaration(ConversionPatternRewriter &rewriter) { - auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); - StringRef funcName("printf"); - Operation *funcOp = moduleOp.lookupSymbol(funcName); - if (funcOp) - return cast(*funcOp); - - auto *context = rewriter.getContext(); - - // int printf(char* format, ...) - SmallVector argsType{ptr_ty(context)}; - auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, true); - - ConversionPatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(moduleOp.getBody()); - - auto op = rewriter.create(UnknownLoc::get(context), - funcName, funcType); - return op; -} - -void emitPrintf(ConversionPatternRewriter &rewriter, Value formatStrStart, - int /*formatStrByteCount*/, ValueRange args) { - auto loc = UnknownLoc::get(rewriter.getContext()); - SmallVector formatStrAndArgs{formatStrStart}; - for (auto arg : args) { - formatStrAndArgs.push_back(arg); - } - call(getPrintfDeclaration(rewriter), formatStrAndArgs); -} - -Value llPrintf(StringRef msg, ValueRange args, - ConversionPatternRewriter &rewriter, - int *formatStrByteCount = nullptr) { - assert(!msg.empty() && "printf with empty string not supported"); - llvm::SmallString<64> msgNewline(msg); - msgNewline.push_back('\n'); - msgNewline.push_back('\0'); - Value msgValue = - LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()), rewriter, - "printfFormat_", msgNewline); - emitPrintf(rewriter, msgValue, msgNewline.size_in_bytes(), args); - if (formatStrByteCount) - *formatStrByteCount = msgNewline.size_in_bytes(); - return msgValue; -} - // TODO: This code is the same as the GPU-backend code. Consider refactoring. std::string getFormatSubstr(Value value, bool hex = false, std::optional width = std::nullopt) { @@ -123,44 +76,139 @@ std::string getFormatSubstr(Value value, bool hex = false, return ""; } -// TritonCPU's device_print prints all values in the same line unlike GPUs -// and interpreter where each value is printed in a separate line. 
-struct PrintOpConversion : public ConvertOpToLLVMPattern { - explicit PrintOpConversion(LLVMTypeConverter &typeConverter) - : mlir::ConvertOpToLLVMPattern(typeConverter) {} +LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter, + bool printf) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = printf ? "printf" : "triton_vector_print"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType; + if (printf) + argsType = {ptr_ty(ctx)}; + else + argsType = {i32_ty, i32_ty, i32_ty, ptr_ty(ctx), + ptr_ty(ctx), i32_ty, i32_ty, i64_ty}; + + auto funcType = + LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ printf); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); +} + +void llPrintf(StringRef prefix, std::array pid, + std::optional arg, ConversionPatternRewriter &rewriter, + bool hex = false) { + assert(!prefix.empty() && "printf with empty string not supported"); + auto loc = UnknownLoc::get(rewriter.getContext()); + + std::string formatStr; + llvm::raw_string_ostream os(formatStr); + os << "(" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1]) + << ", " << getFormatSubstr(pid[2]) << ")" << prefix; + if (arg.has_value()) + os << getFormatSubstr(arg.value(), hex); + + llvm::SmallString<64> formatStrNewline(formatStr); + formatStrNewline.push_back('\n'); + formatStrNewline.push_back('\0'); + Value formatStrValue = + LLVM::addStringToModule(loc, rewriter, "printfFormat_", formatStrNewline); + + SmallVector allArgs{formatStrValue}; + for (auto elem : pid) + allArgs.push_back(elem); + if (arg.has_value()) + allArgs.push_back(arg.value()); + call(getPrintFuncDecl(rewriter, true), allArgs); +} + +void llVectorPrint(std::array pid, StringRef prefix, Value ptr, + bool isInteger, uint32_t bitWidth, int64_t numElem, + ConversionPatternRewriter &rewriter) { + assert(!prefix.empty()); + auto loc = UnknownLoc::get(rewriter.getContext()); + + llvm::SmallString<64> prefixStr(prefix); + prefixStr.push_back('\0'); + Value prefixValue = + LLVM::addStringToModule(loc, rewriter, "vectorPrintPrefix_", prefixStr); + + SmallVector allArgs; + for (auto elem : pid) + allArgs.push_back(elem); + allArgs.push_back(prefixValue); + allArgs.push_back(ptr); + allArgs.push_back(i32_val(isInteger)); + allArgs.push_back(i32_val(bitWidth)); + allArgs.push_back(i64_val(numElem)); + call(getPrintFuncDecl(rewriter, false), allArgs); +} + +bool usePrintf(triton::cpu::PrintOp op) { + // Simply use printf if no operand or the operand is scalar. + if (op.getNumOperands() == 0) + return true; + + // tt.print is already decomposed to triton_cpu.print per value. 
+ assert(op.getNumOperands() == 1); + Type oprType = op.getOperands()[0].getType(); + return (oprType.isIntOrIndexOrFloat() || isa(oprType)); +} + +struct PrintOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + matchAndRewrite(triton::cpu::PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto getPid = [&](int axis) { return getProgramId(op->getParentOfType(), axis); }; - SmallVector values = {getPid(0), getPid(1), getPid(2)}; - - std::string formatStr; - llvm::raw_string_ostream os(formatStr); - os << "(" << getFormatSubstr(values[0]) << ", " - << getFormatSubstr(values[1]) << ", " << getFormatSubstr(values[2]) - << ")" << op.getPrefix(); - - for (size_t i = 0; i < op.getNumOperands(); i++) { - auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter); - if (dyn_cast(op.getOperand(i).getType())) { - llvm_unreachable("Not implemented for tensor types"); + std::array pid = {getPid(0), getPid(1), getPid(2)}; + + if (usePrintf(op)) { + if (op.getNumOperands() == 0) { + llPrintf(op.getPrefix(), pid, std::nullopt, rewriter); + } else { + Value llOpr = adaptor.getOperands()[0]; + llPrintf(op.getPrefix(), pid, llOpr, rewriter, op.getHex()); } - - // Only support scalars for now. - assert(elems.size() == 1); - if (i != 0) { - os << ", "; + } else { + Value llOpr = adaptor.getOperands()[0]; + auto vecShapedType = cast(op.getOperands()[0].getType()); + // Currently, we only support 1D vector printing. + if (vecShapedType.getRank() == 1) { + + // To get the pointer of the vector, create an alloca and store it. + auto ptrType = ptr_ty(rewriter.getContext()); + auto ptr = rewriter.create(loc, ptrType, + llOpr.getType(), i32_val(1)); + rewriter.create(loc, llOpr, ptr); + + // TODO: Consider passing an encoded element type information instead of + // booleans and separate bit width. + llVectorPrint(pid, op.getPrefix(), ptr, + vecShapedType.getElementType().isInteger(), + vecShapedType.getElementTypeBitWidth(), + vecShapedType.getNumElements(), rewriter); + } else { + // TODO: support 2D+ vector printing. 
+ std::string msg{op.getPrefix()}; + llvm::raw_string_ostream os(msg); + os << "<>"; + llPrintf(msg, pid, std::nullopt, rewriter); } - os << getFormatSubstr(elems[0], op.getHex()); - values.push_back(elems[0]); } - llPrintf(formatStr, values, rewriter); rewriter.eraseOp(op); return success(); } diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp index 144cb57b1115..821ea6f954b2 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp @@ -10,11 +10,8 @@ TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( addConversion([&](triton::PointerType type) -> std::optional { return convertTritonPointerType(type); }); - addConversion([this](RankedTensorType tensorTy) -> std::optional { - if (isa(tensorTy.getElementType())) - return VectorType::get(tensorTy.getShape(), - IntegerType::get(tensorTy.getContext(), 64)); - return std::nullopt; + addConversion([this](RankedTensorType type) -> std::optional { + return convertTritonTensorType(type); }); } @@ -41,3 +38,11 @@ Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( } return LLVM::LLVMPointerType::get(ctx); } + +Type TritonCPUToLLVMTypeConverter::convertTritonTensorType( + RankedTensorType type) { + if (isa(type.getElementType())) + return VectorType::get(type.getShape(), + IntegerType::get(type.getContext(), 64)); + llvm_unreachable("No tensor types are expected in TTCIR"); +} diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h index 35d74a9ec430..02123796ff37 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.h @@ -17,6 +17,7 @@ class TritonCPUToLLVMTypeConverter : public LLVMTypeConverter { const DataLayoutAnalysis *analysis = nullptr); Type convertTritonPointerType(triton::PointerType type); + Type convertTritonTensorType(RankedTensorType type); }; #endif diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp index 271a8b28559e..d113e6671531 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -85,9 +85,9 @@ struct CdivToDiv : public OpRewritePattern { arith::ConstantOp addCstDef; Value addOtherVal; - if (addCstDef = addOpDef.getLhs().getDefiningOp()) + if ((addCstDef = addOpDef.getLhs().getDefiningOp())) addOtherVal = addOpDef.getRhs(); - else if (addCstDef = addOpDef.getRhs().getDefiningOp()) + else if ((addCstDef = addOpDef.getRhs().getDefiningOp())) addOtherVal = addOpDef.getLhs(); else return failure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index b200a47da92d..18e675044881 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonToTritonCPU ConvertAtomicOps.cpp ConvertControlFlowOps.cpp + ConvertDebugOps.cpp ConvertDotOp.cpp ConvertElementwiseOps.cpp ConvertElemManipOps.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp new file mode 100644 index 000000000000..cf6e6704bc28 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -0,0 +1,100 @@ +#include "TypeConverter.h" + +#include 
"cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTDEBUGOPS +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class DebugOpsConversionTarget : public ConversionTarget { +public: + explicit DebugOpsConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + + addIllegalOp(); + } +}; + +struct PrintOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + // It lowers to triton_cpu.print after converting tensor types to vectors. + // (tt.print doesn't accept vector types, so we have this intermediate op.) + if (op.getNumOperands() == 0) { + rewriter.create(loc, op.getPrefix(), op.getHex(), + ValueRange{}); + } else { + // triton_cpu.print takes up to one vector or scalar operand. It prints + // each value as a separate print call like the GPU and interpreter. + for (size_t i = 0; i < op.getNumOperands(); i++) { + Value opr = op.getOperands()[i]; + // TODO: Consider using memrefs for general N-dimensional vectors. + rewriter.create(loc, op.getPrefix(), op.getHex(), + rewriter.getRemappedValue(opr)); + } + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertDebugOps + : public triton::impl::ConvertDebugOpsBase { + using ConvertDebugOpsBase::ConvertDebugOpsBase; + + ConvertDebugOps() : ConvertDebugOpsBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + DebugOpsConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDebugOps() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 0d69ca6c8ab7..3d232ddb2530 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -1,6 +1,196 @@ +#include +#include #include +#include +#include +#include +#include -void triton_assert(bool cond, char *c) { +#if defined(_MSC_VER) +#define EXPORT __declspec(dllexport) +#elif defined(__GNUC__) +#define EXPORT __attribute__((visibility("default"))) +#else +#define EXPORT +#endif + +namespace { + +// A poor man's Torch-like pretty print for tensors and vectors. 
+const int MAX_FLOAT_WIDTH = 8; +const int FLOAT_PREC = 4; +const int ELEMS_PER_LINE = 8; + +struct FormatInfo { + bool isInt; + int bitWidth; + int maxIntDigits; + bool hasNegative; + bool scientific; +}; + +template +std::pair +computeDigitInfoHelper(const void *array, size_t index) { + T elem = static_cast(array)[index]; + if (elem == 0) + return {1, false}; + return {static_cast(std::log10(std::abs(elem))) + 1, elem < 0}; +} + +std::pair computeDigitInfo(void *vec, int32_t isInt, + int32_t bitWidth, size_t index) { + + if (isInt == 0) { + if (bitWidth == 32) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 64) + return computeDigitInfoHelper(vec, index); + else + assert(false && "Unsupported bitWidth"); + } else { + // TODO: Handle signed types? + if (bitWidth == 64) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 32) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 16) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 8) + return computeDigitInfoHelper(vec, index); + else + assert(false && "Unsupported bitWidth"); + } +} + +FormatInfo getFormatInfo(void *vec, bool isInt, int32_t bitWidth, + int64_t numElem) { + // Compute the max/min widths for pretty printing. + int maxIntDigits = 0; + int minIntDigits = std::numeric_limits::max(); + bool hasNegative = false; + for (int64_t i = 0; i < numElem; ++i) { + auto [digits, negative] = computeDigitInfo(vec, isInt, bitWidth, i); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, digits); + minIntDigits = std::min(minIntDigits, digits); + } + // Fallback to the scientific format for certain cases. + bool scientific; + if (isInt) { + scientific = false; + } else { + scientific = maxIntDigits + 2 + (hasNegative ? 1 : 0) > MAX_FLOAT_WIDTH; + scientific |= maxIntDigits - minIntDigits > 3; + } + return {isInt, bitWidth, maxIntDigits, hasNegative, scientific}; +} + +template +void printElementHelper(std::stringstream &ss, const void *array, + size_t index) { + ss << static_cast(array)[index]; +} + +void printElement(std::stringstream &ss, const void *vec, size_t index, + bool isInt, int bitWidth) { + if (isInt == 0) { + switch (bitWidth) { + case 32: + printElementHelper(ss, vec, index); + break; + case 64: + printElementHelper(ss, vec, index); + break; + default: + assert(false && "Unsupported bitWidth"); + } + } else { + switch (bitWidth) { + case 64: + printElementHelper(ss, vec, index); + break; + case 32: + printElementHelper(ss, vec, index); + break; + case 16: + printElementHelper(ss, vec, index); + break; + case 8: + // TODO: Seems like not working well. Need to fix it. 
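+      // Note: std::stringstream streams int8_t through the character
+      // overload, so values are rendered as raw chars rather than numbers;
+      // casting to int before streaming would likely fix this.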
+ printElementHelper(ss, vec, index); + break; + default: + assert(false && "Unsupported bitWidth"); + } + } +} + +void printFormattedElement(std::stringstream &ss, void *vec, size_t index, + const FormatInfo &formatInfo) { + int padding = 0; + auto [digits, negative] = + computeDigitInfo(vec, formatInfo.isInt, formatInfo.bitWidth, index); + if (!negative && formatInfo.hasNegative) + padding++; + if (formatInfo.scientific) { + ss << std::scientific << std::setw(MAX_FLOAT_WIDTH) + << std::setprecision(FLOAT_PREC) << std::string(padding, ' '); + printElement(ss, vec, index, formatInfo.isInt, formatInfo.bitWidth); + } else { + padding += formatInfo.maxIntDigits - digits; + ss << std::fixed << std::setprecision(FLOAT_PREC) + << std::string(padding, ' '); + printElement(ss, vec, index, formatInfo.isInt, formatInfo.bitWidth); + } +} +} // namespace + +extern "C" { + +EXPORT void triton_assert(bool cond, char *c) { if (!cond) fprintf(stderr, "%s\n", c); } + +// Print the pid prefix like the GPU ad interpreter. And vectors are printed +// similar to Torch's printing like the following: +// (1, 0, 0) x: [ -0.4963, -1.7682, 2.0885, 3.1320, -4.3074, 5.6341, +// -6.4901, 7.8964, -8.4556, -9.6323, -10.3489, -11.4017, +// -12.0223, 13.1689, 14.2939, -15.5185] +// +// TODO: Implement for higher dimension vectors. +EXPORT void triton_vector_print(int32_t pid0, int32_t pid1, int32_t pid2, + const char *prefix, void *vec, int32_t isInt, + int32_t bitWidth, int64_t numElem) { + + FormatInfo formatInfo = getFormatInfo(vec, isInt != 0, bitWidth, numElem); + + std::stringstream ss; + ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix << "["; + const size_t header = ss.str().size(); + + if (numElem <= ELEMS_PER_LINE) { + for (int i = 0; i < numElem; i++) { + printFormattedElement(ss, vec, i, formatInfo); + if (i != numElem - 1) + ss << ", "; + } + } else { + // TODO: Too many lines? Omit the middle lines. + for (int i = 0; i < numElem; i++) { + printFormattedElement(ss, vec, i, formatInfo); + if (i == numElem - 1) + break; + if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { + ss << ",\n" << std::string(header, ' '); + } else { + ss << ", "; + } + } + } + ss << "]\n"; + std::cout << ss.str() << std::flush; +} + +} // extern "C" diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 99975bb2f5d2..90c886b6ec85 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -54,6 +54,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_convert_atomic_ops", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertAtomicOps()); }); + m.def("add_convert_debug_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDebugOps()); + }); m.def("add_optimize_masks", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createOptimizeMasks()); }); From 1f97f149ce43334aa55b6dfbfb0058d7f7fda0c2 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Wed, 14 Aug 2024 21:27:29 +0200 Subject: [PATCH 090/165] [Pytests] Add several suits (#106) * [Pytests] Add several suits This commits adds to testing several of already working suits. * [Pytest] Support CPU device Enable suits for cpu device. 
- language - python/test/unit/language/test_random.py - python/test/unit/language/test_standard.py - runtime - python/test/unit/runtime/test_bindings.py - python/test/unit/runtime/test_cache.py - python/test/unit/runtime/test_driver.py - python/test/unit/runtime/test_jit.py - python/test/unit/runtime/test_launch.py python/test/unit/runtime/test_cache.py expects creation and usage of files, that doesn't works with multiple workers. --------- Signed-off-by: Dmitrii Makarenko --- .github/workflows/build-test.yml | 15 +++++++++++++-- python/test/unit/language/test_pipeliner.py | 7 ++++++- python/test/unit/language/test_standard.py | 8 +++++--- python/test/unit/runtime/test_cache.py | 9 ++++++--- python/triton/runtime/jit.py | 8 +++++++- 5 files changed, 37 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 46865e1b556e..f675e052acf6 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -83,10 +83,21 @@ jobs: python/test/unit/language/test_annotations.py \ python/test/unit/language/test_block_pointer.py \ python/test/unit/language/test_conversions.py \ + python/test/unit/language/test_compile_errors.py \ + python/test/unit/language/test_decorator.py \ + python/test/unit/language/test_pipeliner.py \ + python/test/unit/language/test_random.py \ + python/test/unit/language/test_standard.py \ + python/test/unit/runtime/test_bindings.py \ + python/test/unit/runtime/test_driver.py \ + python/test/unit/runtime/test_jit.py \ + python/test/unit/runtime/test_launch.py \ + python/test/unit/runtime/test_subproc.py \ + python/test/unit/runtime/test_autotuner.py \ + python/test/unit/runtime/test_cache.py \ python/test/unit/cpu/test_libdevice.py \ python/test/unit/cpu/test_libmvec.py \ - python/test/unit/cpu/test_opt.py \ - python/test/unit/runtime/test_autotuner.py + python/test/unit/cpu/test_opt.py - name: Run lit tests run: | diff --git a/python/test/unit/language/test_pipeliner.py b/python/test/unit/language/test_pipeliner.py index 824e75838912..fe09e57e724e 100644 --- a/python/test/unit/language/test_pipeliner.py +++ b/python/test/unit/language/test_pipeliner.py @@ -5,6 +5,7 @@ import triton import triton.language as tl import triton.tools.experimental_descriptor +from test_core import is_cpu from triton._internal_testing import is_cuda, is_hopper, is_hip_cdna, is_hip_mi200, is_hip @@ -273,7 +274,11 @@ def test_pipeline_matmul(scale, device): if scale: ref_out = dot_scale_ref(a, scale_a, b, a_type, b_type) else: - ref_out = torch.matmul(a, b) + if is_cpu(): + ref_out = torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.float16) + else: + ref_out = torch.matmul(a, b) + # Bigger tolerance for AMD MI200 devices. # MI200 devices use reduced precision fp16 and bf16 and flush input and # output denormal values to zero. 
Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices diff --git a/python/test/unit/language/test_standard.py b/python/test/unit/language/test_standard.py index df5784d92641..06e471710cea 100644 --- a/python/test/unit/language/test_standard.py +++ b/python/test/unit/language/test_standard.py @@ -3,7 +3,7 @@ import torch import triton.language as tl -from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random +from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random, is_cpu # --------------- # test maximum/minimum ops @@ -26,7 +26,8 @@ def test_maximum_minium(dtype, op, device): @pytest.mark.interpreter -@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize( + "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]]) @pytest.mark.parametrize("descending", [False, True]) @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) def test_sort(M, N, descending, dtype_str, device): @@ -54,7 +55,8 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr @pytest.mark.interpreter -@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]]) +@pytest.mark.parametrize( + "M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]]) @pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16']) def test_flip(M, N, dtype_str, device): diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index c74e305e42bf..eb174835f5ae 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -8,7 +8,7 @@ import triton import triton.language as tl -from triton.runtime.jit import JITFunction +from triton.runtime.jit import JITFunction, get_device_key from triton._internal_testing import is_hip @@ -191,7 +191,7 @@ def kernel(X, i: tl.int32): x = torch.empty(1, dtype=torch.int32, device=device) - device = getattr(torch, device).current_device() + device_key = get_device_key() kernel[(1, )](x, 1) kernel[(1, )](x, 8) kernel[(1, )](x, 16) @@ -407,6 +407,9 @@ def test_jit_debug(device) -> None: def kernel(tmp): tl.device_assert(tl.load(tmp) == 1, "tmp == 1") + if device == "cpu": + pytest.skip('Device Assert is not yet supported on CPU') + device = getattr(torch, device).current_device() tmp = torch.tensor([1], dtype=torch.int32, device=device) assert len(kernel.device_caches[device][0]) == 0 @@ -474,7 +477,7 @@ def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr): tl.device_assert(idx < 32, "idx < 32") tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx)) - device = getattr(torch, device).current_device() + device = get_device_key() # get the serialized specialization data specialization_data = None diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 749f4870dfbe..f0454993c129 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -429,6 +429,12 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options type_canonicalisation_dict[v] = v +def get_device_key(): + target = driver.active.get_current_target() + device = driver.active.get_current_device() + return f"{target.backend}:{device}" + + class JITFunction(KernelInterface[T]): # Hook for inspecting compiled 
functions and modules cache_hook = None @@ -668,7 +674,7 @@ def preload(self, specialization_data): from ..compiler import compile, ASTSource import json import triton.language as tl - device = driver.active.get_current_device() + device_key = get_device_key() deserialized_obj = json.loads(specialization_data) if deserialized_obj['name'] != self.fn.__name__: raise RuntimeError( From 89d1b424f8faf855fa0f94bb13e47964a8039fbe Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Thu, 15 Aug 2024 09:21:02 -0700 Subject: [PATCH 091/165] Identify dot product pattern (mul followed by a sum) for bf16, and convert it to dot product intrinsics (#56) * Identify dot product pattern (mul followed by a sum) for bf16, and convert it to dot product intrinsics --- third_party/cpu/backend/compiler.py | 3 + .../include/TritonCPUTransforms/OptCommon.h | 17 ++ .../cpu/include/TritonCPUTransforms/Passes.h | 2 + .../cpu/include/TritonCPUTransforms/Passes.td | 17 ++ .../lib/TritonCPUTransforms/CMakeLists.txt | 1 + .../TritonCPUTransforms/ConvertDotProduct.cpp | 272 ++++++++++++++++++ third_party/cpu/triton_cpu.cc | 3 + 7 files changed, 315 insertions(+) create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index cf73ce52b0b2..02bdc926a527 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -119,6 +119,9 @@ def make_tttcir(self, mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() cpu.passes.ttcpuir.add_optimize_masks(pm) + convert_bf16_dot_product = self.cpu_arch == "aarch64" and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features + if convert_bf16_dot_product: + cpu.passes.ttcpuir.add_convert_dot_product(pm) promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features # We don't have any lowering for mixed precision matmuls, so always use casts for now convert_mixed_precision_matmul = True diff --git a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h index 0fe6dc64c5b2..a2e94f894caf 100644 --- a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h +++ b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h @@ -2,6 +2,7 @@ #define TRITONCPU_CONVERSION_TRITONCPUOPT_OPTCOMMON_H #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" @@ -103,6 +104,22 @@ Value cstLike(Location loc, Value tySrc, T val, PatternRewriter &rewriter) { return fpCst(loc, tySrc.getType(), val, rewriter); } +inline Value shapeCast(Location loc, Value in, VectorType outTy, + PatternRewriter &rewriter) { + VectorType inTy = cast(in.getType()); + assert(outTy.getElementType() == inTy.getElementType()); + assert(outTy.getNumElements() == inTy.getNumElements()); + return rewriter.create(loc, outTy, in); +} + +inline Value shapeCast(Location loc, Value in, + std::initializer_list shapes, + PatternRewriter &rewriter) { + VectorType inTy = cast(in.getType()); + VectorType outTy = VectorType::get(shapes, inTy.getElementType()); + return shapeCast(loc, in, outTy, rewriter); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index fffc485bbf3d..426c93f93cce 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ 
b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -29,6 +29,8 @@ createDecomposeFpConversions(bool decomposeBf16Conversions, bool decomposeFp8Conversions); std::unique_ptr> createOptimizeMasks(); +std::unique_ptr> createConvertDotProduct(); + #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index 7c9c59b40091..fdc86563ef86 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -75,4 +75,21 @@ def OptimizeMasks : Pass<"triton-cpu-optimize-masks", "mlir::ModuleOp"> { "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertDotProduct : Pass<"triton-cpu-convert-dot-product", "mlir::ModuleOp"> { + let summary = "Convert dot product op."; + let description = [{ + This pass is used for indentifying dot product pattern + (for example, elementwise mul followed by a sum) and + converting it to dot product intrinsics like bfdot. + }]; + + let constructor = "mlir::triton::cpu::createConvertDotProduct()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt index afbe29d8aea5..86277b3f0490 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonCPUTransforms + ConvertDotProduct.cpp ConvertUnsupportedOps.cpp DecomposeFpConversions.cpp OptimizeMasks.cpp diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp new file mode 100644 index 000000000000..2255db8f64f8 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp @@ -0,0 +1,272 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "include/triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTPRODUCT +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// TODO: support SVE and different vector width +// We currently only supported Arm Neon (128 bit vector). +// To support scalable vectors in SVE, we need to generate +// vector-length agnostic (VLA) code using vector.vscale. +// To support other platform (AVX512 for X86), we need to +// change the vectorBitWidth and the intrinsics. +constexpr int vectorBitWidth = 128; + +// This function is used to identify bf16 dot product (expressed by elementwise +// multiplication follwed by a sum). +// For example, the following pattern: tl.sum(a * x[None, :], axis=1) +// is used to express a dot product. +// Since x is broadcated for the elementwise multiplication. 
And tl.sum will +// cast its bf16 input to fp32. +// The pattern in MLIR will be: +// BroadcastOp -> MulFOp -> ExtFOp -> MultiDimReductionOp +bool isBf16DotProduct(vector::MultiDimReductionOp op, Value &matInput, + Value &vecInput, PatternRewriter &rewriter) { + Value src = op.getSource(); + Value acc = op.getAcc(); + auto srcTy = cast(src.getType()); + auto accTy = cast(acc.getType()); + auto resTy = cast(op.getType()); + + auto srcRank = srcTy.getRank(); + auto outNum = srcTy.getDimSize(0); + + if (resTy != accTy || srcRank != 2 || !isFp32(srcTy)) + return false; + + if (op.isReducedDim(0) || !op.isReducedDim(1)) + return false; + + if (op.getKind() != vector::CombiningKind::ADD) + return false; + + auto extFOp = src.getDefiningOp(); + + if (!extFOp || !extFOp->hasOneUse()) + return false; + + auto mulFOp = extFOp.getIn().getDefiningOp(); + + if (!mulFOp || !mulFOp->hasOneUse()) + return false; + + Value lhs = mulFOp.getLhs(); + Value rhs = mulFOp.getRhs(); + + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); + + if (!isBf16(lhsTy) || !isBf16(rhsTy)) + return false; + + const int lanes = + vectorBitWidth / lhsTy.getElementType().getIntOrFloatBitWidth(); + int64_t kVal = lhsTy.getDimSize(1); + + if (outNum < 1) + return false; + + // TODO: masking is not currrently supported + if (kVal % lanes != 0) + return false; + + if (outNum == 1) { + matInput = lhs; + vecInput = rhs; + } else { + vector::BroadcastOp broadCastOp; + if (rhs.getDefiningOp()) { + matInput = lhs; + broadCastOp = rhs.getDefiningOp(); + } else { + matInput = rhs; + broadCastOp = lhs.getDefiningOp(); + } + if (!broadCastOp || !broadCastOp->hasOneUse()) + return false; + vecInput = broadCastOp.getSource(); + } + + if (cast(vecInput.getType()).getDimSize(0) != 1 || + cast(matInput.getType()).getDimSize(0) != outNum) + return false; + + return true; +} + +struct ConvertMulSumToDotHorizontalSum + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Value acc = op.getAcc(); + auto resTy = cast(op.getType()); + + Value matInput; + Value vecInput; + + bool isMatch = isBf16DotProduct(op, matInput, vecInput, rewriter); + if (!isMatch) + return failure(); + + // Once we get the matrix input (NxK) and vector input (K), + // where N is the output channel dimension + // and K is the reduction dimension. + // We will generate the following code to perform the dot product. + // For each output channel: + // we will pull 8 bf16 elements from the vector and matrix each time when + // we iterate over the K dimension. + // We will then use bfdot to perform sum-of-products on pairs of + // bf16 elements, accumulate and get 4 fp32 outputs. + // After the iteration over the K dimension finishes, we will use a + // horizontal sum (faddv) to sum the 4 fp32 into a single fp32. + // We will also share the vector input across the output channels + // to reduce the number of loads. 
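+    // As an illustrative sketch (not part of this patch), the kind of Triton
+    // source that produces this pattern is a bf16 gemv written as an
+    // elementwise multiply followed by a row reduction, e.g.:
+    //   a = tl.load(A + offs_n[:, None] * K + offs_k[None, :])  # NxK bf16
+    //   x = tl.load(X + offs_k)                                 # K   bf16
+    //   y = tl.sum(a * x[None, :], axis=1)                      # N   fp32
+    // (pointer names are hypothetical); the broadcast of x, the bf16
+    // multiply, and the fp32-accumulating sum map to BroadcastOp, MulFOp,
+    // ExtFOp, and MultiDimReductionOp as described above.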
+ // For example, if we dot product a size 2x16 matrix with a size 16 vector, + // the pseudo code will be: + // matrix = shapecast(matrix, 2x2x8) + // vector = shapecast(vector, 2x8) + // out = zeros(2x4, fp32) + // out[0] = bfdot(out[0], matrix[0][0], vector[0]) + // out[1] = bfdot(out[1], matrix[1][0], vector[0]) + // out[0] = bfdot(out[0], matrix[0][1], vector[1]) + // out[1] = bfdot(out[1], matrix[1][1], vector[1]) + // out_0 = faddv(out[0]) : 4xfp32 -> fp32 + // out_1 = faddv(out[1]) : 4xfp32 -> fp32 + + auto matInputTy = cast(matInput.getType()); + auto vecInputTy = cast(vecInput.getType()); + + const int lanes = + vectorBitWidth / matInputTy.getElementType().getIntOrFloatBitWidth(); + const int resLanes = + vectorBitWidth / resTy.getElementType().getIntOrFloatBitWidth(); + int64_t kVal = matInputTy.getDimSize(1); + + // numOfOutputChannels is the number of output channels (N) + const int numOfOutputChannels = matInputTy.getDimSize(0); + // numOfBfdotOps is the number of bfdots needed for each output channel. + const int numOfBfdotOps = kVal / lanes; + + matInput = shapeCast(loc, matInput, + {numOfOutputChannels, numOfBfdotOps, lanes}, rewriter); + vecInput = shapeCast(loc, vecInput, {numOfBfdotOps, lanes}, rewriter); + + SmallVector outRes(numOfOutputChannels); + SmallVector mats(numOfOutputChannels); + + Type outResTy = VectorType::get(resLanes, resTy.getElementType()); + + Value zeroRes = rewriter.create( + loc, outResTy, rewriter.getZeroAttr(outResTy)); + for (int64_t outIdx = 0; outIdx < numOfOutputChannels; outIdx += 1) { + outRes[outIdx] = zeroRes; + // Intermediate array to store each row of the input matrix. + mats[outIdx] = rewriter.create(loc, matInput, outIdx); + } + + SmallVector resultTypes = {outResTy}; + // TODO: this intrinsic is hard-coded for Arm Neon + llvm::StringRef bfdotIntrinsic("llvm.aarch64.neon.bfdot.v4f32.v8bf16"); + SmallVector args; + + for (int64_t idx = 0; idx < numOfBfdotOps; idx += 1) { + auto subVec = rewriter.create(loc, vecInput, idx); + for (int64_t outIdx = 0; outIdx < numOfOutputChannels; outIdx += 1) { + auto subMat = + rewriter.create(loc, mats[outIdx], idx); + args = {outRes[outIdx], subMat, subVec}; + // bfdot instruction: + // https://developer.arm.com/documentation/ddi0602/2024-06/SIMD-FP-Instructions/BFDOT--vector---BFloat16-floating-point-dot-product--vector-- + // LLVM fast math flags: + // https://llvm.org/docs/LangRef.html#fast-math-flags + // This bfdot intrinsic will perform an unfused sum-of-products of each + // pair of adjacent bf16 elements in the source vectors (8 bf16), and + // output 4 fp32 elements. 
+ auto callIntrOp = rewriter.create( + loc, resultTypes, bfdotIntrinsic, args, LLVM::FastmathFlags::fast); + outRes[outIdx] = callIntrOp.getResult(0); + } + } + + Value res = rewriter.create(loc, resTy, + rewriter.getZeroAttr(resTy)); + + resultTypes = {resTy.getElementType()}; + // TODO: this intrinsic is hard-coded for Arm Neon + llvm::StringRef horizSumIntrinsic("llvm.aarch64.neon.faddv.f32.v4f32"); + for (int64_t outIdx = 0; outIdx < numOfOutputChannels; outIdx += 1) { + args = {outRes[outIdx]}; + // This horizontal sum intrinsic will sum all fp32 elements in the source + // vector into a single fp32 element + auto callIntrOp = rewriter.create( + loc, resultTypes, horizSumIntrinsic, args, LLVM::FastmathFlags::fast); + res = rewriter.create(loc, callIntrOp.getResult(0), res, + outIdx); + } + + if (!isZeroConst(acc)) { + res = rewriter.create(loc, res, acc); + } + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertDotProduct + : public triton::cpu::impl::ConvertDotProductBase { + ConvertDotProduct() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + + patterns.add(context); + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotProduct() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 90c886b6ec85..618fa6994051 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -60,6 +60,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_optimize_masks", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createOptimizeMasks()); }); + m.def("add_convert_dot_product", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDotProduct()); + }); m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm, bool promote_bf16_to_fp32, bool convert_mixed_precision_matmul, bool promote_lib_math_to_fp32) { From e2247f24e9b1a67774d29fd6bf8b1b6a5666dcb5 Mon Sep 17 00:00:00 2001 From: RuiqiGao Date: Fri, 16 Aug 2024 09:14:01 -0700 Subject: [PATCH 092/165] Add optional packing for converting bf16 dot product. (#118) * Add optional packing for converting bf16 dot product. 
* Resolve review comments --- third_party/cpu/backend/compiler.py | 3 +- .../cpu/include/TritonCPUTransforms/Passes.h | 2 + .../cpu/include/TritonCPUTransforms/Passes.td | 6 + .../TritonCPUTransforms/ConvertDotProduct.cpp | 230 +++++++++++++++++- third_party/cpu/triton_cpu.cc | 5 +- 5 files changed, 237 insertions(+), 9 deletions(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 02bdc926a527..79e704d859e0 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -121,7 +121,8 @@ def make_tttcir(self, mod, metadata, opt): cpu.passes.ttcpuir.add_optimize_masks(pm) convert_bf16_dot_product = self.cpu_arch == "aarch64" and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features if convert_bf16_dot_product: - cpu.passes.ttcpuir.add_convert_dot_product(pm) + use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1" + cpu.passes.ttcpuir.add_convert_dot_product(pm, use_horizontal_sum) promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features # We don't have any lowering for mixed precision matmuls, so always use casts for now convert_mixed_precision_matmul = True diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index 426c93f93cce..ec4c10498891 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -30,6 +30,8 @@ createDecomposeFpConversions(bool decomposeBf16Conversions, std::unique_ptr> createOptimizeMasks(); std::unique_ptr> createConvertDotProduct(); +std::unique_ptr> +createConvertDotProduct(bool useHorizontalSum); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index fdc86563ef86..656eff4f4fe7 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -83,6 +83,12 @@ def ConvertDotProduct : Pass<"triton-cpu-convert-dot-product", "mlir::ModuleOp"> converting it to dot product intrinsics like bfdot. }]; + let options = [ + Option<"useHorizontalSum", "use-horizontal-sum", + "bool", /*default*/"true", + "Use Horizontal Sum kernel for the dot product (gemv). Otherwise use a kernel with packing.">, + ]; + let constructor = "mlir::triton::cpu::createConvertDotProduct()"; let dependentDialects = ["mlir::arith::ArithDialect", diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp index 2255db8f64f8..93cccb60fb1b 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp @@ -44,8 +44,9 @@ constexpr int vectorBitWidth = 128; // cast its bf16 input to fp32. 
// The pattern in MLIR will be: // BroadcastOp -> MulFOp -> ExtFOp -> MultiDimReductionOp -bool isBf16DotProduct(vector::MultiDimReductionOp op, Value &matInput, - Value &vecInput, PatternRewriter &rewriter) { +bool isBf16DotProduct(vector::MultiDimReductionOp op, bool useHorizontalSum, + Value &matInput, Value &vecInput, + PatternRewriter &rewriter) { Value src = op.getSource(); Value acc = op.getAcc(); auto srcTy = cast(src.getType()); @@ -85,11 +86,19 @@ bool isBf16DotProduct(vector::MultiDimReductionOp op, Value &matInput, const int lanes = vectorBitWidth / lhsTy.getElementType().getIntOrFloatBitWidth(); + const int resultLanes = + vectorBitWidth / resTy.getElementType().getIntOrFloatBitWidth(); int64_t kVal = lhsTy.getDimSize(1); if (outNum < 1) return false; + if (!useHorizontalSum) { + // TODO: masking is not currrently supported + if (outNum % resultLanes != 0) + return false; + } + // TODO: masking is not currrently supported if (kVal % lanes != 0) return false; @@ -132,7 +141,8 @@ struct ConvertMulSumToDotHorizontalSum Value matInput; Value vecInput; - bool isMatch = isBf16DotProduct(op, matInput, vecInput, rewriter); + bool isMatch = isBf16DotProduct(op, /*useHorizontalSum=*/true, matInput, + vecInput, rewriter); if (!isMatch) return failure(); @@ -166,7 +176,7 @@ struct ConvertMulSumToDotHorizontalSum const int lanes = vectorBitWidth / matInputTy.getElementType().getIntOrFloatBitWidth(); - const int resLanes = + const int resultLanes = vectorBitWidth / resTy.getElementType().getIntOrFloatBitWidth(); int64_t kVal = matInputTy.getDimSize(1); @@ -182,7 +192,7 @@ struct ConvertMulSumToDotHorizontalSum SmallVector outRes(numOfOutputChannels); SmallVector mats(numOfOutputChannels); - Type outResTy = VectorType::get(resLanes, resTy.getElementType()); + Type outResTy = VectorType::get(resultLanes, resTy.getElementType()); Value zeroRes = rewriter.create( loc, outResTy, rewriter.getZeroAttr(outResTy)); @@ -240,9 +250,208 @@ struct ConvertMulSumToDotHorizontalSum } }; +struct ConvertMulSumToDotPack + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp op, + PatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Value acc = op.getAcc(); + auto resTy = cast(op.getType()); + + Value matInput; + Value vecInput; + + bool isMatch = isBf16DotProduct(op, /*useHorizontalSum=*/false, matInput, + vecInput, rewriter); + if (!isMatch) + return failure(); + + // Once we get the matrix input (NxK) and vector input (K), + // where N is the output channel dimension + // and K is the reduction dimension. + // We will generate the following code to perform the dot product. + // We will first transpose the matrix so that the output channel dimension + // is continuous, so we can store multiple output channels in one + // SIMD register. + // Then we will loop over the K dimension. + // For each iteration over K, we will pull 2 bf16 from the input vector. + // Inside the K loop, we will also iterate over the output channels. + // For each iteration over the output channel, we will pull + // 4 output channel (each containing 2 bf16). + // Then we will broadcast the 2 bf16 from the input vector, + // dot product it with the 4 output channels (each containing 2 bf16), + // and accumulate it with 4 outputs. + // We will iterate over N until all output channels are processed. + // Then we will move to the next 2 bf16 from the input vector (the K loop). 
+ // We will also share the vector input across the output channels. + // For example, if we dot product a size 8x8 matrix with a size 8 vector, + // the generated pseudo code will be: + // Dimension: + // N: the output channel dimension + // n0: the number of SIMD registers needed to store the output + // -- N / 4 (2 in this case) + // n1: the number of outputs stored per SIMD register + // -- 4 + // K: the reduction dimension + // k0: the number of SIMD registers needed for the input vector + // -- K / 8 (1 in this case) + // k1: the number of lanes per SIMD register + // -- 4 + // k2: the number of bf16 elements per SIMD lane + // -- 2 + // matrix = shapecast(matrix, 8x4x2) + // shape: NxK -> Nx(k0xk1)xk2 + // matrix = tranpose(matrix, 1, 0, 2) : 8x4x2xbf16 -> 4x8x2xbf16 + // shape: Nx(k0xk1)xk2 -> (k0xk1)xNxk2 + // matrix = shapecast(matrix, 1x4x2x4x2xbf16) + // shape: (k0xk1)xNxk2 -> k0xk1xn0xn1xk2 + // vector = shapecast(vector, 1x4x2) + // shape: K -> k0xk1xk2 + // out = zeros(2x4, fp32) + // shape: n0xn1 + // subvec = broadcast(vector[0][0]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][0][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][0][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // subvec = broadcast(vector[0][1]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][1][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][1][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // subvec = broadcast(vector[0][2]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][2][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][2][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // subvec = broadcast(vector[0][3]) : 2xbf16 -> 4x2xbf16 + // shape: k2 -> k1xk2 + // out[0] = bfdot(out[0], matrix[0][3][0], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out[1] = bfdot(out[1], matrix[0][3][1], subvec) + // shape: (n1, n1xk2, k1xk2) -> n1 + // out = shapecast(out, 8) : 2x4xfp32 -> 8xfp32 + // shape: n0xn1 -> N + + auto matInputTy = cast(matInput.getType()); + auto vecInputTy = cast(vecInput.getType()); + + const int lanes = + vectorBitWidth / matInputTy.getElementType().getIntOrFloatBitWidth(); + const int resultLanes = + vectorBitWidth / resTy.getElementType().getIntOrFloatBitWidth(); + int64_t kVal = matInputTy.getDimSize(1); + + // numOfOutputChannels is the number of output channels (N) + const int numOfOutputChannels = matInputTy.getDimSize(0); + // numOfOutputRegs is the number of SIMD registers needed to store the + // output. + const int numOfOutputRegs = numOfOutputChannels / resultLanes; + // numOfVecRegs is the number of SIMD registers needed for the + // input vector. + const int numOfVecRegs = kVal / lanes; + // numOfVecPairs is the number of pairs (pair of bf16 elements) for the + // input vector. + const int numOfVecPairs = numOfVecRegs * resultLanes; + + VectorType fullResTy = + VectorType::get({numOfOutputRegs, resultLanes}, resTy.getElementType()); + + VectorType subResTy = VectorType::get(resultLanes, resTy.getElementType()); + + acc = shapeCast(loc, acc, fullResTy, rewriter); + + Type inElemTy = matInputTy.getElementType(); + // Integer type for a pair of bf16 elements + Type pairTy = IntegerType::get(ctx, 32); + + vecInput = + shapeCast(loc, vecInput, {numOfVecRegs, resultLanes, 2}, rewriter); + // We bitcast here because we are pulling pairs of bf16 each time. 
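+    // Each pair of adjacent bf16 elements is reinterpreted as a single i32
+    // lane so that a whole pair can be broadcast and shuffled as one element
+    // below.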
+ vecInput = rewriter.create( + loc, VectorType::get({numOfVecRegs, resultLanes, 1}, pairTy), vecInput); + vecInput = shapeCast(loc, vecInput, {numOfVecRegs, resultLanes}, rewriter); + + matInput = shapeCast(loc, matInput, {numOfOutputChannels, numOfVecPairs, 2}, + rewriter); + // We bitcast here because we are pulling pairs of bf16 each time. + matInput = rewriter.create( + loc, VectorType::get({numOfOutputChannels, numOfVecPairs, 1}, pairTy), + matInput); + matInput = shapeCast(loc, matInput, {numOfOutputChannels, numOfVecPairs}, + rewriter); + // Packing/Transposing the weight matrix so that + // the output channel is continuous + matInput = rewriter.create( + loc, matInput, SmallVector{1, 0}); + matInput = shapeCast( + loc, matInput, + {numOfVecRegs, resultLanes, numOfOutputRegs, resultLanes}, rewriter); + + Value res = rewriter.create( + loc, fullResTy, rewriter.getZeroAttr(fullResTy)); + SmallVector resultTypes = {subResTy}; + // TODO: this intrinsic is hard-coded for Arm Neon + llvm::StringRef bfdotIntrinsic("llvm.aarch64.neon.bfdot.v4f32.v8bf16"); + SmallVector args; + + SmallVector subRes(numOfOutputRegs); + for (int64_t outIdx = 0; outIdx < numOfOutputRegs; outIdx += 1) { + subRes[outIdx] = rewriter.create(loc, acc, outIdx); + } + for (int64_t idx = 0; idx < numOfVecRegs; idx += 1) { + Value fullVec = rewriter.create(loc, vecInput, idx); + for (int64_t vecIdx = 0; vecIdx < resultLanes; vecIdx += 1) { + // shuffle mask used to broadcast the 'vecIdx'th lane of fullVec + SmallVector shuffleMask(resultLanes, vecIdx); + // Broadcasting the 'vecIdx'th lane of fullVec + Value subVec = rewriter.create(loc, fullVec, fullVec, + shuffleMask); + subVec = rewriter.create( + loc, VectorType::get({lanes}, inElemTy), subVec); + for (int64_t outIdx = 0; outIdx < numOfOutputRegs; outIdx += 1) { + Value subMat = rewriter.create( + loc, matInput, SmallVector{idx, vecIdx, outIdx}); + subMat = rewriter.create( + loc, VectorType::get({lanes}, inElemTy), subMat); + args = {subRes[outIdx], subMat, subVec}; + // bfdot instruction: + // https://developer.arm.com/documentation/ddi0602/2024-06/SIMD-FP-Instructions/BFDOT--vector---BFloat16-floating-point-dot-product--vector-- + // LLVM fast math flags: + // https://llvm.org/docs/LangRef.html#fast-math-flags + // This bfdot intrinsic will perform an unfused sum-of-products of + // each pair of adjacent bf16 elements in the source vectors + // (8 bf16), and output 4 fp32 elements. 
+ auto callIntrOp = rewriter.create( + loc, resultTypes, bfdotIntrinsic, args, + LLVM::FastmathFlags::fast); + subRes[outIdx] = callIntrOp.getResult(0); + } + } + } + + for (int64_t outIdx = 0; outIdx < numOfOutputRegs; outIdx += 1) { + res = rewriter.create(loc, subRes[outIdx], res, outIdx); + } + + res = shapeCast(loc, res, resTy, rewriter); + rewriter.replaceOp(op, res); + return success(); + } +}; + struct ConvertDotProduct : public triton::cpu::impl::ConvertDotProductBase { ConvertDotProduct() = default; + ConvertDotProduct(bool useHorizontalSum) { + this->useHorizontalSum = useHorizontalSum; + } void runOnOperation() override { MLIRContext *context = &getContext(); @@ -250,7 +459,11 @@ struct ConvertDotProduct RewritePatternSet patterns(context); - patterns.add(context); + if (useHorizontalSum) { + patterns.add(context); + } else { + patterns.add(context); + } if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) return signalPassFailure(); @@ -267,6 +480,11 @@ std::unique_ptr> createConvertDotProduct() { return std::make_unique(); } +std::unique_ptr> +createConvertDotProduct(bool useHorizontalSum) { + return std::make_unique(useHorizontalSum); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 618fa6994051..8f3608384aaf 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -60,8 +60,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_optimize_masks", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createOptimizeMasks()); }); - m.def("add_convert_dot_product", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertDotProduct()); + m.def("add_convert_dot_product", [](mlir::PassManager &pm, + bool useHorizontalSum) { + pm.addPass(mlir::triton::cpu::createConvertDotProduct(useHorizontalSum)); }); m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm, bool promote_bf16_to_fp32, From 5ae0b706dbe5c3ec19b2d11922a4844da51cffe8 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 21 Aug 2024 14:58:04 -0500 Subject: [PATCH 093/165] Add load/store scalarization through loops. 
(#119) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 4 - test/TritonCPU/convert-memory-ops.mlir | 2 +- third_party/cpu/backend/compiler.py | 2 +- .../cpu/include/TritonToTritonCPU/Passes.h | 2 + .../cpu/include/TritonToTritonCPU/Passes.td | 6 + .../ConvertUnsupportedOps.cpp | 58 ++- .../TritonToTritonCPU/ConvertMemoryOps.cpp | 479 ++++++++++++++++-- third_party/cpu/triton_cpu.cc | 7 +- 8 files changed, 514 insertions(+), 46 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 5c432fd432a3..be8d914f9122 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3331,10 +3331,6 @@ def test_permute(dtype_str, shape, perm, num_ctas, device): if shape == (128, 128) and dtype_str == 'float32': pytest.skip("TODO Out of LDS for float32 with shape 128x128") - if is_cpu(): - # FIXME: compilation time for big shapes is too long - shape = tuple(dim // 4 for dim in shape) - # triton kernel @triton.jit def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): diff --git a/test/TritonCPU/convert-memory-ops.mlir b/test/TritonCPU/convert-memory-ops.mlir index c98747269fdc..32f8630cab84 100644 --- a/test/TritonCPU/convert-memory-ops.mlir +++ b/test/TritonCPU/convert-memory-ops.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops | FileCheck %s +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops=use-scalar-loops=false | FileCheck %s // Convert strided masked loads to scalar loads. diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 79e704d859e0..cbc5d60b3957 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -96,7 +96,7 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - cpu.passes.ttcpuir.add_convert_memory_ops(pm) + cpu.passes.ttcpuir.add_convert_memory_ops(pm, True) cpu.passes.ttcpuir.add_convert_ptr_ops(pm) cpu.passes.ttcpuir.add_convert_elementwise_ops(pm) cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 699518361e01..ac2c03b6abf8 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -21,6 +21,8 @@ namespace cpu { std::unique_ptr> createConvertElementwiseOps(); std::unique_ptr> createConvertElemManipOps(); std::unique_ptr> createConvertMemoryOps(); +std::unique_ptr> +createConvertMemoryOps(bool useScalarLoops); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 612ce135cc65..161fec9babcd 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -10,6 +10,12 @@ def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { }]; let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; + let options = [ + Option<"useScalarLoops", "use-scalar-loops", + "bool", /*default*/"true", + "Enable lowering of tensor loads and stores to scalar loops.">, + ]; + let dependentDialects = ["mlir::arith::ArithDialect", "mlir::memref::MemRefDialect", 
"mlir::vector::VectorDialect", diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index af80b6d6c0f9..d3d5ca95d70d 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -1,6 +1,7 @@ #include "cpu/include/TritonCPUTransforms/OptCommon.h" #include "cpu/include/TritonCPUTransforms/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -87,23 +88,30 @@ struct ConvertIToBf16ToFp32 : public OpRewritePattern { }; Value convertMemRefToI16(Value memRef, PatternRewriter &rewriter) { - Value res; MemRefType memRefTy = cast(memRef.getType()); - Type newMemRefTy = + if (memRefTy.getElementType().isInteger()) + return memRef; + + Value res; + MemRefType newMemRefTy = MemRefType::get(memRefTy.getShape(), rewriter.getI16Type(), memRefTy.getLayout(), memRefTy.getMemorySpace()); auto insPoint = rewriter.saveInsertionPoint(); rewriter.setInsertionPointAfter(memRef.getDefiningOp()); // Memory references for masked operations and transfers are always built - // with PtrToMemRefOp or ExtractMemRefOp. + // with PtrToMemRefOp, ExtractMemRefOp, or memref::AllocaOp. if (auto castOp = memRef.getDefiningOp()) { res = rewriter.create(memRef.getLoc(), newMemRefTy, castOp.getSrc()); - } else { - auto extractOp = memRef.getDefiningOp(); - assert(extractOp && "Unexpected memref producer"); + } else if (auto extractOp = memRef.getDefiningOp()) { res = rewriter.create(memRef.getLoc(), newMemRefTy, extractOp.getSrc()); + } else { + auto allocaOp = memRef.getDefiningOp(); + assert(allocaOp && "Unexpected memref producer"); + res = rewriter.create(allocaOp.getLoc(), newMemRefTy, + allocaOp.getAlignmentAttr()); + rewriter.replaceOp(allocaOp, res); } rewriter.restoreInsertionPoint(insPoint); return res; @@ -195,6 +203,42 @@ struct ConvertBf16TransferWriteOp } }; +struct ConvertBf16LoadOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::LoadOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getType())) + return failure(); + + Location loc = op.getLoc(); + Value newMemRef = convertMemRefToI16(op.getMemRef(), rewriter); + Value intVal = + rewriter.create(loc, newMemRef, op.getIndices()); + Value res = rewriter.create(loc, op.getType(), intVal); + rewriter.replaceOp(op, res); + return success(); + } +}; + +struct ConvertBf16StoreOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::StoreOp op, + PatternRewriter &rewriter) const override { + if (!isBf16(op.getValue().getType())) + return failure(); + + Location loc = op.getLoc(); + Value newMemRef = convertMemRefToI16(op.getMemRef(), rewriter); + Value intVal = rewriter.create( + loc, toInt16(op.getValue().getType()), op.getValue()); + rewriter.replaceOpWithNewOp(op, intVal, newMemRef, + op.getIndices()); + return success(); + } +}; + struct ConvertBf16Abs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -353,6 +397,8 @@ struct ConvertUnsupportedOps patterns.add(context); patterns.add(context); patterns.add(context); + patterns.add(context); + patterns.add(context); } patterns.add(context); if (convertMixedPrecisionMatmul) { diff --git 
a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index e872aa63fdff..dd9bba70ff1d 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -22,8 +22,10 @@ namespace mlir { namespace triton { +namespace cpu { #define GEN_PASS_DEF_CONVERTMEMORYOPS #include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace cpu } // namespace triton } // namespace mlir @@ -41,9 +43,11 @@ struct MemoryOpConversion : public OpConversionPattern { MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, - TypeConverter &typeConverter, MLIRContext *context) + TypeConverter &typeConverter, bool useScalarLoops, + MLIRContext *context) : OpConversionPattern(typeConverter, context), - axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis) {} + axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis), + genScalarLoops(useScalarLoops) {} Value extractScalarPointer(Location loc, Value ptrs, ArrayRef indices, @@ -64,18 +68,29 @@ struct MemoryOpConversion : public OpConversionPattern { } bool canComputeScalarValue(Value vals) const { - if (auto def = vals.getDefiningOp()) { - return canComputeScalarValue(def.getPtr()) && - canComputeScalarValue(def.getOffset()); - } - - if (auto def = vals.getDefiningOp()) { - return canComputeScalarValue(def.getLhs()) && - canComputeScalarValue(def.getRhs()); + auto def = vals.getDefiningOp(); + if (!def) + return false; + + if (isa(*def)) { + for (auto op : def->getOperands()) { + if (!canComputeScalarValue(op)) + return false; + } + return true; } - if (vals.getDefiningOp() || vals.getDefiningOp()) { + if (isa(*def)) return true; + + if (auto cst = dyn_cast(def)) { + if (auto denseVal = dyn_cast(cst.getValue())) { + return denseVal.isSplat(); + } + return false; } return false; @@ -83,19 +98,6 @@ struct MemoryOpConversion : public OpConversionPattern { Value computeScalarValue(Value vals, ArrayRef indices, ConversionPatternRewriter &rewriter) const { - if (auto def = vals.getDefiningOp()) { - Value ptr = computeScalarValue(def.getPtr(), indices, rewriter); - Value offs = computeScalarValue(def.getOffset(), indices, rewriter); - return rewriter.create(def.getLoc(), ptr.getType(), ptr, offs); - } - - if (auto def = vals.getDefiningOp()) { - Value lhs = computeScalarValue(def.getLhs(), indices, rewriter); - Value rhs = computeScalarValue(def.getRhs(), indices, rewriter); - return rewriter.create(def.getLoc(), lhs.getType(), lhs, - rhs); - } - if (auto def = vals.getDefiningOp()) { return def.getSrc(); } @@ -109,7 +111,160 @@ struct MemoryOpConversion : public OpConversionPattern { rewriter.getIntegerAttr(elemTy, start + indices[0])); } - return Value(); + if (auto def = vals.getDefiningOp()) { + // Find broadcasted dimensions and replace indices for those dimensions + // with 0 (broadcasted dimension always has size 1). + SmallVector newIndices; + auto sourceTy = cast(def.getSrc().getType()); + auto targetTy = cast(def.getType()); + assert(sourceTy.getRank() == indices.size() && "Mismatched rank"); + for (int64_t i = 0; i < sourceTy.getRank(); ++i) { + if (sourceTy.getShape()[i] != targetTy.getShape()[i]) + newIndices.push_back(0); + else + newIndices.push_back(indices[i]); + } + return computeScalarValue(def.getSrc(), newIndices, rewriter); + } + + if (auto def = vals.getDefiningOp()) { + // Remove index at expanded dimension. 
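+      // tt.expand_dims inserts a unit dimension at getAxis(), so the source
+      // element is addressed by the same indices with that position dropped.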
+ SmallVector newIndices(indices); + newIndices.erase(newIndices.begin() + def.getAxis()); + return computeScalarValue(def.getSrc(), newIndices, rewriter); + } + + if (auto def = vals.getDefiningOp()) { + auto denseVal = cast(def.getValue()); + assert(denseVal.isSplat()); + auto scalarAttr = denseVal.getSplatValue(); + Value res = rewriter.create( + def.getLoc(), scalarAttr.getType(), scalarAttr); + return res; + } + + if (auto def = vals.getDefiningOp()) { + // Permute indices. + SmallVector newIndices; + auto order = def.getOrder(); + assert(indices.size() == order.size() && "Mismatched rank"); + for (auto idx : order) + newIndices.push_back(indices[idx]); + return computeScalarValue(def.getSrc(), newIndices, rewriter); + } + + // Generic case where we copy defining op with scalar operands. + auto def = vals.getDefiningOp(); + OperationState newState(def->getLoc(), def->getName()); + for (auto op : def->getOperands()) { + newState.operands.push_back(computeScalarValue(op, indices, rewriter)); + } + assert(def->getResults().size() == 1); + newState.types.push_back( + cast(def->getResultTypes()[0]).getElementType()); + newState.attributes = def->getAttrs(); + return rewriter.create(newState)->getResult(0); + } + + Value computeScalarValue(Value vals, ValueRange indices, + ConversionPatternRewriter &rewriter, + DenseMap &valMap) const { + if (valMap.count(vals)) + return valMap.at(vals); + + if (auto def = vals.getDefiningOp()) { + return def.getSrc(); + } + + if (auto def = vals.getDefiningOp()) { + auto denseVal = cast(def.getValue()); + assert(denseVal.isSplat()); + auto scalarAttr = denseVal.getSplatValue(); + Value res = rewriter.create( + def.getLoc(), scalarAttr.getType(), scalarAttr); + valMap[vals] = res; + return res; + } + + if (auto def = vals.getDefiningOp()) { + assert(indices.size() == 1); + int32_t start = static_cast(def.getStart()); + Type elemTy = cast(def.getType()).getElementType(); + Value startVal = rewriter.create( + def.getLoc(), elemTy, rewriter.getIntegerAttr(elemTy, start)); + Value index = indices[0]; + if (!elemTy.isIndex()) + index = + rewriter.create(def.getLoc(), elemTy, index); + Value res = + rewriter.create(def.getLoc(), elemTy, startVal, index); + valMap[vals] = res; + return res; + } + + if (auto def = vals.getDefiningOp()) { + // Find broadcasted dimensions and replace indices for those dimensions + // with 0 (broadcasted dimension has always size 1). + SmallVector newIndices; + auto sourceTy = cast(def.getSrc().getType()); + auto targetTy = cast(def.getType()); + assert(sourceTy.getRank() == indices.size() && "Mismatched rank"); + for (int64_t i = 0; i < sourceTy.getRank(); ++i) { + if (sourceTy.getShape()[i] != targetTy.getShape()[i]) + newIndices.push_back( + rewriter.create(def.getLoc(), 0)); + else + newIndices.push_back(indices[i]); + } + // The original cache is only used for the original set of indices. + DenseMap tmpValMap; + Value res = + computeScalarValue(def.getSrc(), newIndices, rewriter, tmpValMap); + valMap[vals] = res; + return res; + } + + if (auto def = vals.getDefiningOp()) { + // Remove index at expanded dimension. + SmallVector newIndices = indices; + newIndices.erase(newIndices.begin() + def.getAxis()); + // The original cache is only used for the original set of indices. + DenseMap tmpValMap; + Value res = + computeScalarValue(def.getSrc(), newIndices, rewriter, tmpValMap); + valMap[vals] = res; + return res; + } + + if (auto def = vals.getDefiningOp()) { + // Permute indices. 
+ SmallVector newIndices; + auto order = def.getOrder(); + assert(indices.size() == order.size() && "Mismatched rank"); + for (auto idx : order) + newIndices.push_back(indices[idx]); + // The original cache is only used for the original set of indices. + DenseMap tmpValMap; + Value res = + computeScalarValue(def.getSrc(), newIndices, rewriter, tmpValMap); + valMap[vals] = res; + return res; + } + + // Generic case where we copy defining op with scalar operands. + auto def = vals.getDefiningOp(); + OperationState newState(def->getLoc(), def->getName()); + for (auto op : def->getOperands()) { + newState.operands.push_back( + computeScalarValue(op, indices, rewriter, valMap)); + } + assert(def->getResults().size() == 1); + newState.types.push_back( + cast(def->getResultTypes()[0]).getElementType()); + newState.attributes = def->getAttrs(); + Value res = rewriter.create(newState)->getResult(0); + valMap[vals] = res; + return res; } Value extractMemRef(Location loc, Value ptr, @@ -144,9 +299,249 @@ struct MemoryOpConversion : public OpConversionPattern { rewriter.getZeroAttr(resTy.getElementType()))); } + Value createAlloca(Location loc, MemRefType ty, Operation *before, + ConversionPatternRewriter &rewriter) const { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(before); + return rewriter.create( + loc, ty, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); + } + + // If tensor is not null and its element cannot be recomputed in a scalar + // loop, then store it to a temporary buffer. + Value maybeStoreVecToTempBuf(Location loc, Value vals, Value zeroIdx, + Operation *allocaPoint, + ConversionPatternRewriter &rewriter) const { + if (!vals || canComputeScalarValue(vals)) + return nullptr; + + auto vec = rewriter.getRemappedValue(vals); + auto vecTy = cast(vec.getType()); + auto elemTy = vecTy.getElementType(); + // Memref of i1 assumes one element per byte when we load/store element, + // but vector store (through transfer write) would write 1 bit per element. + if (elemTy.isInteger(1)) { + elemTy = rewriter.getI8Type(); + vec = rewriter.create( + loc, VectorType::get(vecTy.getShape(), elemTy), vec); + } + auto memRefTy = MemRefType::get(vecTy.getShape(), elemTy); + Value memRef = createAlloca(vals.getLoc(), memRefTy, allocaPoint, rewriter); + SmallVector indices(vecTy.getRank(), zeroIdx); + rewriter.create(vals.getLoc(), vec, memRef, + indices); + return memRef; + } + + // Load scalar element from a temporary buffer or recompute it if the + // buffer doesn't exist. + Value computeOrLoadScalarValue(Value vals, Value tmpVals, ValueRange indices, + ConversionPatternRewriter &rewriter, + DenseMap &valMap) const { + // Allow null value for easier handling of optional arguments. + if (!vals) + return nullptr; + + // Load value from a temp buffer if any. + if (tmpVals) { + Value val = + rewriter.create(vals.getLoc(), tmpVals, indices); + // If we load a pointer then additional cast is needed because tensor of + // pointers is transformed into a vector of integers. + auto elemTy = dyn_cast(vals.getType()).getElementType(); + if (isa(elemTy)) + val = rewriter.create(vals.getLoc(), elemTy, val); + // We need to transform loaded i8 back to i1. 
+ else if (elemTy.isInteger(1)) + val = rewriter.create(val.getLoc(), + rewriter.getI1Type(), val); + return val; + } + + return computeScalarValue(vals, indices, rewriter, valMap); + } + + LogicalResult scalarizeWithLoop(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + auto loc = loadOp.getLoc(); + auto vecTy = + dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + + auto ptrs = loadOp.getPtr(); + auto mask = loadOp.getMask(); + auto other = loadOp.getOther(); + auto cache = loadOp.getCache(); + auto evict = loadOp.getEvict(); + auto isVolatile = loadOp.getIsVolatile(); + + // Create some reused constants. + Value zeroIdx = rewriter.create(loc, 0); + Value oneIdx = rewriter.create(loc, 1); + + // There is alloca_scope operation to control alloca scopes. But its usage + // in combination with nested SCF and multi-dimensional vectors make it + // impossible to lower scopes to LLVM using existing MLIR passes. For now, + // simply allocate temp memory in the function's region. + // TODO: Use alloc for big buffers and revisit alloca scoping. + Operation *allocaPoint = loadOp; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Allocate temp buffer for the result. Write the other value there if + // we cannot write it in a loop. + auto resMemRefTy = + MemRefType::get(vecTy.getShape(), vecTy.getElementType()); + Value resMemRef = createAlloca(loc, resMemRefTy, allocaPoint, rewriter); + bool storeOtherInLoop = mask; + if (other && !canComputeScalarValue(other)) { + SmallVector indices(vecTy.getRank(), zeroIdx); + rewriter.create( + loc, rewriter.getRemappedValue(other), resMemRef, indices); + storeOtherInLoop = false; + } + + // Store a tensor of pointers and mask into a temp buf if we can't + // compute them in a loop. + Value tmpPtrs = + maybeStoreVecToTempBuf(loc, ptrs, zeroIdx, allocaPoint, rewriter); + Value tmpMask = + maybeStoreVecToTempBuf(loc, mask, zeroIdx, allocaPoint, rewriter); + + // Create for-loops to iterate through all vector dimensions. + SmallVector forOps; + SmallVector ivs; + for (int64_t i = 0; i < vecTy.getRank(); ++i) { + Value upperBound = + rewriter.create(loc, vecTy.getShape()[i]); + auto forOp = + rewriter.create(loc, zeroIdx, upperBound, oneIdx); + forOps.push_back(forOp); + ivs.push_back(forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Compute or load a scalar arguments. + DenseMap valMap; + Value scalarPtr = + computeOrLoadScalarValue(ptrs, tmpPtrs, ivs, rewriter, valMap); + Value scalarMask = + computeOrLoadScalarValue(mask, tmpMask, ivs, rewriter, valMap); + Value scalarOther; + if (storeOtherInLoop) { + if (other) { + scalarOther = computeScalarValue(other, ivs, rewriter, valMap); + } else { + scalarOther = rewriter.create( + loc, vecTy.getElementType(), + rewriter.getZeroAttr(vecTy.getElementType())); + } + } + + if (!mask) { + // Regular load case. + Value val = rewriter.create(loc, scalarPtr, cache, evict, + isVolatile); + rewriter.create(loc, val, resMemRef, ivs); + } else { + // Conditional load case + rewriter.create( + loc, scalarMask, + [&](OpBuilder &builder, Location loc) { + Value val = builder.create(loc, scalarPtr, cache, + evict, isVolatile); + builder.create(loc, val, resMemRef, ivs); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + if (storeOtherInLoop) + builder.create(loc, scalarOther, resMemRef, ivs); + builder.create(loc); + }); + } + + // Load vector from the temp storage and return it from alloca scope. 
+ rewriter.setInsertionPointAfter(forOps.front()); + SmallVector indices(vecTy.getRank(), zeroIdx); + Value res = + rewriter.create(loc, vecTy, resMemRef, indices); + + rewriter.replaceOp(loadOp, res); + return success(); + } + + LogicalResult scalarizeWithLoop(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + auto loc = storeOp.getLoc(); + auto vecTy = dyn_cast( + getTypeConverter()->convertType(storeOp.getValue().getType())); + + auto ptrs = storeOp.getPtr(); + auto mask = storeOp.getMask(); + auto vals = storeOp.getValue(); + auto cache = storeOp.getCache(); + auto evict = storeOp.getEvict(); + + // Create some reused constants. + Value zeroIdx = rewriter.create(loc, 0); + Value oneIdx = rewriter.create(loc, 1); + + // Alloca is inserted similar to the load case. + Operation *allocaPoint = storeOp; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Store a tensor of pointers, mask, and values into a temp buf if we can't + // compute them in a loop. + Value tmpPtrs = + maybeStoreVecToTempBuf(loc, ptrs, zeroIdx, allocaPoint, rewriter); + Value tmpMask = + maybeStoreVecToTempBuf(loc, mask, zeroIdx, allocaPoint, rewriter); + Value tmpVals = + maybeStoreVecToTempBuf(loc, vals, zeroIdx, allocaPoint, rewriter); + + // Create for-loops to iterate through all vector dimensions. + SmallVector forOps; + SmallVector ivs; + for (int64_t i = 0; i < vecTy.getRank(); ++i) { + Value upperBound = + rewriter.create(loc, vecTy.getShape()[i]); + auto forOp = + rewriter.create(loc, zeroIdx, upperBound, oneIdx); + forOps.push_back(forOp); + ivs.push_back(forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Compute or load scalar args. + DenseMap valMap; + Value scalarPtr = + computeOrLoadScalarValue(ptrs, tmpPtrs, ivs, rewriter, valMap); + Value scalarMask = + computeOrLoadScalarValue(mask, tmpMask, ivs, rewriter, valMap); + Value scalarVal = + computeOrLoadScalarValue(vals, tmpVals, ivs, rewriter, valMap); + + if (!mask) { + // Regular store case. + rewriter.create(loc, scalarPtr, scalarVal, cache, evict); + } else { + // Conditional store case + rewriter.create(loc, scalarMask, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, scalarPtr, scalarVal, cache, evict); + builder.create(loc); + }); + } + + rewriter.eraseOp(storeOp); + return success(); + } + protected: ModuleAxisInfoAnalysis &axisAnalysis; ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; + bool genScalarLoops; }; struct LoadOpConversion : public MemoryOpConversion { @@ -281,6 +676,13 @@ struct LoadOpConversion : public MemoryOpConversion { auto loc = loadOp.getLoc(); auto vecTy = dyn_cast(getTypeConverter()->convertType(loadOp.getType())); + + // We want to avoid a code explosion when scalarize loads of big vectors, + // so try to build a scalar loop. + if (genScalarLoops && vecTy.getNumElements() >= 16 && + succeeded(scalarizeWithLoop(loadOp, rewriter))) + return success(); + auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) : nullptr; @@ -423,11 +825,18 @@ struct StoreOpConversion : public MemoryOpConversion { assert(isa(storeOp.getValue().getType())); auto loc = storeOp.getLoc(); + auto tensorTy = dyn_cast(storeOp.getPtr().getType()); + + // We want to avoid a code explosion when scalarize stores of big vectors, + // so try to build a scalar loop. 
+ if (genScalarLoops && tensorTy.getNumElements() >= 16 && + succeeded(scalarizeWithLoop(storeOp, rewriter))) + return success(); + auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); auto mask = storeOp.getMask() ? rewriter.getRemappedValue(storeOp.getMask()) : nullptr; auto vals = rewriter.getRemappedValue(storeOp.getValue()); - auto tensorTy = dyn_cast(storeOp.getPtr().getType()); auto ptrTy = tensorTy.getElementType(); auto cache = storeOp.getCache(); auto evict = storeOp.getEvict(); @@ -468,6 +877,7 @@ class MemoryOpConversionTarget : public ConversionTarget { addLegalDialect(); addLegalDialect(); addLegalDialect(); + addLegalDialect(); addLegalDialect(); addLegalDialect(); addLegalOp(); @@ -483,10 +893,12 @@ class MemoryOpConversionTarget : public ConversionTarget { }; struct ConvertMemoryOps - : public triton::impl::ConvertMemoryOpsBase { - using ConvertMemoryOpsBase::ConvertMemoryOpsBase; + : public triton::cpu::impl::ConvertMemoryOpsBase { + ConvertMemoryOps() = default; - ConvertMemoryOps() : ConvertMemoryOpsBase() {} + ConvertMemoryOps(bool useScalarLoops) { + this->useScalarLoops = useScalarLoops; + } void runOnOperation() override { MLIRContext *context = &getContext(); @@ -498,9 +910,9 @@ struct ConvertMemoryOps TritonToTritonCPUTypeConverter pointerConverter; RewritePatternSet patterns(context); patterns.add(axisInfoAnalysis, shapeInfoAnalysis, - pointerConverter, context); + pointerConverter, useScalarLoops, context); patterns.add(axisInfoAnalysis, shapeInfoAnalysis, - pointerConverter, context); + pointerConverter, useScalarLoops, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); @@ -517,6 +929,11 @@ std::unique_ptr> createConvertMemoryOps() { return std::make_unique(); } +std::unique_ptr> +createConvertMemoryOps(bool useScalarLoops) { + return std::make_unique(useScalarLoops); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 8f3608384aaf..3680bb20ace6 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -24,9 +24,10 @@ namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; - m.def("add_convert_memory_ops", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); - }); + m.def("add_convert_memory_ops", + [](mlir::PassManager &pm, bool useScalarLoops) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps(useScalarLoops)); + }); m.def("add_convert_ptr_ops", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertPtrOps()); }); From b46f0858cff585e99a27aa20d6cba13a1a1dda73 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 21 Aug 2024 17:40:16 -0500 Subject: [PATCH 094/165] Fix typo. 
(#122) Signed-off-by: Ilya Enkovich --- python/tutorials/05-layer-norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index 76693a85b7f4..bc5716b8792c 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -292,7 +292,7 @@ def backward(ctx, dy): layer_norm = LayerNorm.apply device = triton.runtime.driver.active.get_current_target().backend # Torch doesn't support operations in float16 on CPU so use float32 instead -dtype = torch.float32 if device == 'cpu' else torch.flaot16 +dtype = torch.float32 if device == 'cpu' else torch.float16 def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE): From f83cece037b21a87999149acdefe5cfbb1ba54ef Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 21 Aug 2024 19:16:20 -0500 Subject: [PATCH 095/165] Add lit tests for load/store scalarization. (#121) Signed-off-by: Ilya Enkovich --- test/TritonCPU/scalarize-memory-ops.mlir | 113 +++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 test/TritonCPU/scalarize-memory-ops.mlir diff --git a/test/TritonCPU/scalarize-memory-ops.mlir b/test/TritonCPU/scalarize-memory-ops.mlir new file mode 100644 index 000000000000..f62bbb5765f7 --- /dev/null +++ b/test/TritonCPU/scalarize-memory-ops.mlir @@ -0,0 +1,113 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops=use-scalar-loops=true -cse -canonicalize | FileCheck %s + +// Convert strided masked load and store to loops. Pointer and mask should be scalarized. +// TODO: There is an optimization opportunity to fuse loops. +// TODO: There is an optimization opportunity to reuse temp buffers. + +// CHECK-LABEL: @strided_masked_load_store +// CHECK: %[[ALLOCA1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: scf.for %[[IV1:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[IV1_I32:.*]] = arith.index_castui %[[IV1]] : index to i32 +// CHECK-NEXT: %[[IDX1:.*]] = arith.muli %[[IV1_I32]], %c3_i32 : i32 +// CHECK-NEXT: %[[PTR1:.*]] = tt.addptr %arg0, %[[IDX1]] : !tt.ptr, i32 +// CHECK-NEXT: %[[MASK1:.*]] = arith.cmpi slt, %[[IDX1]], %arg2 : i32 +// CHECK-NEXT: scf.if %[[MASK1]] { +// CHECK-NEXT: %[[VAL1:.*]] = tt.load %[[PTR1]] : !tt.ptr +// CHECK-NEXT: memref.store %[[VAL1]], %[[ALLOCA1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: memref.store %{{.*}}, %[[ALLOCA1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %[[VEC_VAL:.*]] = vector.transfer_read %[[ALLOCA1]][%c0], %{{.*}} {in_bounds = [true]} : memref<128xf32>, vector<128xf32> +// CHECK-NEXT: %[[ALLOCA2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: vector.transfer_write %[[VEC_VAL]], %[[ALLOCA2]][%c0] {in_bounds = [true]} : vector<128xf32>, memref<128xf32> +// CHECK-NEXT: scf.for %[[IV2:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[IV2_I32:.*]] = arith.index_castui %[[IV2]] : index to i32 +// CHECK-NEXT: %[[IDX2:.*]] = arith.muli %[[IV2_I32]], %c3_i32 : i32 +// CHECK-NEXT: %[[PTR2:.*]] = tt.addptr %arg1, %[[IDX2]] : !tt.ptr, i32 +// CHECK-NEXT: %[[MASK2:.*]] = arith.cmpi slt, %[[IDX2]], %arg2 : i32 +// CHECK-NEXT: %[[VAL2:.*]] = memref.load %[[ALLOCA2]][%[[IV2]]] : memref<128xf32> +// CHECK-NEXT: scf.if %[[MASK2]] { +// CHECK-NEXT: tt.store %[[PTR2]], %[[VAL2]] : !tt.ptr +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { + tt.func public @strided_masked_load_store(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) { + %cst = 
arith.constant dense<1.000000e+00> : tensor<128xf32> + %cst_0 = arith.constant dense<3> : tensor<128xi32> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = arith.muli %0, %cst_0 : tensor<128xi32> + %2 = tt.splat %arg2 : i32 -> tensor<128xi32> + %3 = arith.cmpi slt, %1, %2 : tensor<128xi32> + %4 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %5 = tt.addptr %4, %1 : tensor<128x!tt.ptr>, tensor<128xi32> + %6 = tt.load %5, %3, %cst : tensor<128x!tt.ptr> + %7 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %8 = tt.addptr %7, %1 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %8, %6, %3 : tensor<128x!tt.ptr> + tt.return + } +} + +// ----- + +// Convert indirect masked load and store. Pointer and mask are bufferized. +// TODO: There is an optimization opportunity to fuse loops. +// TODO: There is an optimization opportunity to reuse temp buffers. + +// CHECK-LABEL: @indirect_masked_load_store +// CHECK: %[[ALLOCA_VALS1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: %[[ALLOCA_PTRS1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi64> +// CHECK-NEXT: vector.transfer_write %{{.*}}, %[[ALLOCA_PTRS1]][%c0] {in_bounds = [true]} : vector<128xi64>, memref<128xi64> +// CHECK-NEXT: %[[EXT_MASK:.*]] = arith.extui %{{.*}} : vector<128xi1> to vector<128xi8> +// CHECK-NEXT: %[[ALLOCA_MASK1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi8> +// CHECK-NEXT: vector.transfer_write %[[EXT_MASK]], %[[ALLOCA_MASK1]][%c0] {in_bounds = [true]} : vector<128xi8>, memref<128xi8> +// CHECK-NEXT: scf.for %[[IV1:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[PTR1_INT:.*]] = memref.load %[[ALLOCA_PTRS1]][%[[IV1]]] : memref<128xi64> +// CHECK-NEXT: %[[PTR1:.*]] = tt.int_to_ptr %[[PTR1_INT]] : i64 -> !tt.ptr +// CHECK-NEXT: %[[MASK1_I8:.*]] = memref.load %[[ALLOCA_MASK1]][%[[IV1]]] : memref<128xi8> +// CHECK-NEXT: %[[MASK1:.*]] = arith.trunci %[[MASK1_I8]] : i8 to i1 +// CHECK-NEXT: scf.if %[[MASK1]] { +// CHECK-NEXT: %[[VAL1:.*]] = tt.load %[[PTR1]] : !tt.ptr +// CHECK-NEXT: memref.store %[[VAL1]], %[[ALLOCA_VALS1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } else { +// CHECK-NEXT: memref.store %{{.*}}, %[[ALLOCA_VALS1]][%[[IV1]]] : memref<128xf32> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: %[[VEC_VAL:.*]] = vector.transfer_read %[[ALLOCA_VALS1]][%c0], %{{.*}} {in_bounds = [true]} : memref<128xf32>, vector<128xf32> +// CHECK: %[[ALLOCA_PTRS2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi64> +// CHECK-NEXT: vector.transfer_write %{{.*}}, %[[ALLOCA_PTRS2]][%c0] {in_bounds = [true]} : vector<128xi64>, memref<128xi64> +// CHECK-NEXT: %[[ALLOCA_MASK2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi8> +// CHECK-NEXT: vector.transfer_write %[[EXT_MASK]], %[[ALLOCA_MASK2]][%c0] {in_bounds = [true]} : vector<128xi8>, memref<128xi8> +// CHECK-NEXT: %[[ALLOCA_VALS2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> +// CHECK-NEXT: vector.transfer_write %[[VEC_VAL]], %[[ALLOCA_VALS2]][%c0] {in_bounds = [true]} : vector<128xf32>, memref<128xf32> +// CHECK-NEXT: scf.for %[[IV2:.*]] = %c0 to %c128 step %c1 { +// CHECK-NEXT: %[[PTR2_INT:.*]] = memref.load %[[ALLOCA_PTRS2]][%[[IV2]]] : memref<128xi64> +// CHECK-NEXT: %[[PTR2:.*]] = tt.int_to_ptr %[[PTR1_INT]] : i64 -> !tt.ptr +// CHECK-NEXT: %[[MASK2_I8:.*]] = memref.load %[[ALLOCA_MASK2]][%[[IV2]]] : memref<128xi8> +// CHECK-NEXT: %[[MASK2:.*]] = arith.trunci %[[MASK2_I8]] : i8 to i1 +// CHECK-NEXT: %[[VAL2:.*]] = memref.load 
%[[ALLOCA_VALS2]][%[[IV2]]] : memref<128xf32> +// CHECK-NEXT: scf.if %[[MASK2]] { +// CHECK-NEXT: tt.store %[[PTR2]], %[[VAL2]] : !tt.ptr +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { + tt.func public @indirect_masked_load_store(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32) { + %cst = arith.constant dense<0.000000e+00> : tensor<128xf32> + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + %1 = tt.splat %arg2 : !tt.ptr -> tensor<128x!tt.ptr> + %2 = tt.addptr %1, %0 : tensor<128x!tt.ptr>, tensor<128xi32> + %3 = tt.load %2 : tensor<128x!tt.ptr> + %4 = tt.splat %arg3 : i32 -> tensor<128xi32> + %5 = arith.cmpi slt, %3, %4 : tensor<128xi32> + %6 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr> + %7 = tt.addptr %6, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + %8 = tt.load %7, %5, %cst : tensor<128x!tt.ptr> + %9 = tt.splat %arg1 : !tt.ptr -> tensor<128x!tt.ptr> + %10 = tt.addptr %9, %3 : tensor<128x!tt.ptr>, tensor<128xi32> + tt.store %10, %8, %5 : tensor<128x!tt.ptr> + tt.return + } +} From 544ccf2d720d1b0e929acce83f15de203afae3fd Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Wed, 21 Aug 2024 17:27:20 -0700 Subject: [PATCH 096/165] [cpu][easy] Fix compiler error on clang (#120) --- third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index dd9bba70ff1d..02a458986269 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -392,7 +392,7 @@ struct MemoryOpConversion : public OpConversionPattern { auto resMemRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); Value resMemRef = createAlloca(loc, resMemRefTy, allocaPoint, rewriter); - bool storeOtherInLoop = mask; + bool storeOtherInLoop = static_cast(mask); if (other && !canComputeScalarValue(other)) { SmallVector indices(vecTy.getRank(), zeroIdx); rewriter.create( From 439867bb1e6ab498ab4d6ded44b60cee3fbfd44b Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 22 Aug 2024 19:50:01 +0000 Subject: [PATCH 097/165] Offload a part of masks optimization to the canonicalizer. Signed-off-by: Ilya Enkovich --- test/TritonCPU/optimize-masks.mlir | 40 ++++++++++++++- third_party/cpu/backend/compiler.py | 1 + .../lib/TritonCPUTransforms/OptimizeMasks.cpp | 49 +++++-------------- 3 files changed, 51 insertions(+), 39 deletions(-) diff --git a/test/TritonCPU/optimize-masks.mlir b/test/TritonCPU/optimize-masks.mlir index 2a013699cbac..5ab482a565a6 100644 --- a/test/TritonCPU/optimize-masks.mlir +++ b/test/TritonCPU/optimize-masks.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -triton-cpu-optimize-masks | FileCheck %s +// RUN: triton-opt %s -split-input-file -triton-cpu-optimize-masks -canonicalize | FileCheck %s // Convert strided masked loads to scalar loads. @@ -33,3 +33,41 @@ module { tt.return } } + +// ----- + +// Replace masked load with a regular load and optimize out arith.select. 
+ +// CHECK-LABEL: @optimize_select +// CHECK: vector.load +// CHECK-NEXT: arith.addf +// CHECK-NEXT: arith.addf +// CHECK-NEXT: scf.yield + +module { + tt.func public @optimize_select(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32 {tt.divisibility = 16 : i32}) { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> + %cst_1 = arith.constant dense<1.000000e+00> : vector<16xf32> + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : vector<16xf32> + %0 = vector.splat %arg2 : vector<16xi32> + %1 = scf.for %arg3 = %c0_i32 to %arg2 step %c16_i32 iter_args(%arg4 = %cst_2) -> (vector<16xf32>) : i32 { + %3 = vector.splat %arg3 : vector<16xi32> + %4 = arith.addi %3, %cst_0 : vector<16xi32> + %5 = arith.cmpi slt, %4, %0 : vector<16xi32> + %6 = tt.addptr %arg0, %arg3 : !tt.ptr, i32 + %7 = triton_cpu.ptr_to_memref %6 : -> memref<16xf32> + %8 = vector.maskedload %7[%c0], %5, %cst_2 : memref<16xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32> + %9 = arith.addf %8, %cst_1 : vector<16xf32> + %10 = arith.select %5, %9, %cst_2 : vector<16xi1>, vector<16xf32> + %11 = arith.addf %arg4, %10 : vector<16xf32> + scf.yield %11 : vector<16xf32> + } + %2 = vector.multi_reduction , %1, %cst [0] : vector<16xf32> to f32 + tt.store %arg1, %2 : !tt.ptr + tt.return + } +} diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index cbc5d60b3957..54a375399192 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -119,6 +119,7 @@ def make_tttcir(self, mod, metadata, opt): pm = ir.pass_manager(mod.context) pm.enable_debug() cpu.passes.ttcpuir.add_optimize_masks(pm) + passes.common.add_canonicalizer(pm) convert_bf16_dot_product = self.cpu_arch == "aarch64" and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features if convert_bf16_dot_product: use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1" diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp index d113e6671531..332ed5c97c7b 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -284,11 +284,7 @@ AffineExpr buildMinOrMaxExpr(Value val, bool isSigned, bool isMax, // Check if vector mask is all-ones by checking compared values ranges. // Only simplest cases are covered here, so affine expression is used // to represent a range for now. 
-bool isAlwaysAllOnes(Value mask) { - auto maskDef = mask.getDefiningOp(); - if (!maskDef) - return false; - +bool isAlwaysAllOnes(arith::CmpIOp maskDef) { auto pred = maskDef.getPredicate(); if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) return false; @@ -323,30 +319,16 @@ bool isAlwaysAllOnes(Value mask) { return false; } -struct OptimizeMaskedLoad : public OpRewritePattern { +struct OptimizeMask : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(vector::MaskedLoadOp op, + LogicalResult matchAndRewrite(arith::CmpIOp op, PatternRewriter &rewriter) const override { - if (!isAlwaysAllOnes(op.getMask())) + if (!isAlwaysAllOnes(op)) return failure(); - rewriter.replaceOpWithNewOp(op, op.getType(), op.getBase(), - op.getIndices()); - return success(); - } -}; - -struct OptimizeMaskedStore : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::MaskedStoreOp op, - PatternRewriter &rewriter) const override { - if (!isAlwaysAllOnes(op.getMask())) - return failure(); - - rewriter.replaceOpWithNewOp(op, op.getValueToStore(), - op.getBase(), op.getIndices()); + rewriter.replaceOpWithNewOp( + op, op.getType(), rewriter.getOneAttr(op.getType())); return success(); } }; @@ -362,20 +344,11 @@ struct OptimizeMasks // TODO: This pass optimizes out masks applying a set of very strict // patterns. We should use more generic range and divisibility analysis // to cover more cases and remove dependency on other transformations. - RewritePatternSet patterns1(context); - patterns1.add(context); - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns1)))) - return signalPassFailure(); - - RewritePatternSet patterns2(context); - patterns2.add(context); - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns2)))) - return signalPassFailure(); - - RewritePatternSet patterns3(context); - patterns3.add(context); - patterns3.add(context); - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns3)))) + RewritePatternSet patterns(context); + patterns.add(context); + patterns.add(context); + patterns.add(context); + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) return signalPassFailure(); // TODO: if masks removal failed for loads/stores in a for-loop, we might From 25d5d3e825459bf2f2ed4bdf17141e7217e6fb92 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 29 Aug 2024 08:24:18 -0700 Subject: [PATCH 098/165] Implement get_module_map for cpu backend This is required by the BaseBackend. 
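For context, a minimal sketch (not part of this patch) of a kernel that relies on the module map added below; the kernel body and the choice of function are illustrative, and it assumes the CPU libdevice exposes rsqrt as exercised by the existing libdevice tests:

    # Illustrative only: `triton.language.extra.libdevice` is remapped to the
    # CPU implementations through get_module_map() when compiling for the CPU backend.
    import triton
    import triton.language as tl
    from triton.language.extra import libdevice

    @triton.jit
    def rsqrt_kernel(x_ptr, y_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        tl.store(y_ptr + offs, libdevice.rsqrt(x))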
--- third_party/cpu/backend/compiler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 54a375399192..f9ec13ec5042 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -5,7 +5,8 @@ from pathlib import Path from dataclasses import dataclass -from typing import Any, Tuple +from types import ModuleType +from typing import Any, Dict, Tuple from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget @@ -73,6 +74,10 @@ def get_codegen_implementation(self): codegen_fns = {"min_dot_size": min_dot_size(self.target)} return codegen_fns + def get_module_map(self) -> Dict[str, ModuleType]: + from triton.language.extra.cpu import libdevice + return {"triton.language.extra.libdevice": libdevice} + def load_dialects(self, ctx): cpu.load_dialects(ctx) From eb759af2e0fd57e44343785c46de41b9218d7c51 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 29 Aug 2024 19:46:31 -0400 Subject: [PATCH 099/165] Make CPU runtime lib lookup work for Python 3.8 (#129) The version compatibility pain continues... PyTorch still supports 3.8 so we need this --- third_party/cpu/backend/driver.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index f17a0cba9f30..44d980e01987 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -14,7 +14,11 @@ _dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") # for locating libTritonCPURuntime -_triton_C_dir = importlib.resources.files(triton).joinpath("_C") +try: + _triton_C_dir = importlib.resources.files(triton).joinpath("_C") +except AttributeError: + # resources.files() doesn't exist for Python < 3.9 + _triton_C_dir = importlib.resources.path(triton, "_C").__enter__() include_dirs = [os.path.join(_dirname, "include")] library_dirs = [os.path.join(_dirname, "lib"), _triton_C_dir] From ff4b3475714c8daa4207e838d5ab1d085ad62702 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 29 Aug 2024 21:53:07 -0400 Subject: [PATCH 100/165] Implement device_assert (#126) This implements device_assert using a similar approach to @minjang's implementation of device_print -- we create a new triton_cpu.assert op that takes a vector which we then lower to our own `triton_assert` implementation. One notable difference between our implementation vs the GPU one is that `abort()` on the CPU aborts the entire program, not just a single thread. In contrast, `__assert_fail` on the GPU seems to stop execution of the current thread, but other threads still continue execution. Thus we only get ~one assert message even if the assert condition is false for multiple threads, while the same program on the GPU gives us assert messages for all threads. I suppose there are some ways around this, but for now, let's go with the simplest implementation. 
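For context, a hedged usage sketch (not part of this patch) of how the new assert path is reached from user code; the kernel and tensor names are illustrative, and it assumes device asserts are enabled at launch via debug=True (or TRITON_DEBUG=1):

    # Illustrative only: tl.device_assert becomes a triton_cpu.assert op and, on
    # the CPU backend, lowers to the triton_assert runtime call, which prints the
    # failing block id and message and then calls abort() for the whole process.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def checked_copy(src, dst, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        x = tl.load(src + offs, mask=offs < n)
        tl.device_assert(x >= 0, "negative input")
        tl.store(dst + offs, x, mask=offs < n)

    src = torch.rand(64)            # non-negative inputs, so the assert passes
    dst = torch.empty_like(src)
    checked_copy[(4,)](src, dst, 64, BLOCK=16, debug=True)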
--- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 11 +++ .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 78 ++++++++++++++++--- .../lib/TritonToTritonCPU/ConvertDebugOps.cpp | 25 ++++++ third_party/cpu/runtime/cpu_runtime.cpp | 12 ++- 4 files changed, 114 insertions(+), 12 deletions(-) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 6bcca9ec0d5b..bbd832c44621 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -96,4 +96,15 @@ def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { let hasVerifier = 1; } +def TT_AssertOp : TTC_Op<"assert", [MemoryEffects<[MemWrite]>]> { + let summary = "For correctness checking"; + let description = [{ + Takes a condition tensor, a message string, a file string, a function string, and a line number. + If the condition is false, the message is printed, and the program is aborted. + }]; + let arguments = (ins I1:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); + let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)"; +} + + #endif diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index c60da23b765a..9a39ed25b78d 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -102,6 +102,12 @@ LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter, funcType); } +static StringRef makeNullTerminatedString(StringRef s) { + llvm::SmallString<64> ss(s); + ss.push_back(0); + return ss; +} + void llPrintf(StringRef prefix, std::array pid, std::optional arg, ConversionPatternRewriter &rewriter, bool hex = false) { @@ -135,10 +141,8 @@ void llVectorPrint(std::array pid, StringRef prefix, Value ptr, assert(!prefix.empty()); auto loc = UnknownLoc::get(rewriter.getContext()); - llvm::SmallString<64> prefixStr(prefix); - prefixStr.push_back('\0'); - Value prefixValue = - LLVM::addStringToModule(loc, rewriter, "vectorPrintPrefix_", prefixStr); + Value prefixValue = LLVM::addStringToModule( + loc, rewriter, "vectorPrintPrefix_", makeNullTerminatedString(prefix)); SmallVector allArgs; for (auto elem : pid) @@ -162,6 +166,10 @@ bool usePrintf(triton::cpu::PrintOp op) { return (oprType.isIntOrIndexOrFloat() || isa(oprType)); } +Value getPid(Operation *op, int axis) { + return getProgramId(op->getParentOfType(), axis); +}; + struct PrintOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -170,10 +178,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto getPid = [&](int axis) { - return getProgramId(op->getParentOfType(), axis); - }; - std::array pid = {getPid(0), getPid(1), getPid(2)}; + std::array pid = {getPid(op, 0), getPid(op, 1), getPid(op, 2)}; if (usePrintf(op)) { if (op.getNumOperands() == 0) { @@ -214,6 +219,61 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { } }; +struct AssertOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::cpu::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctx = rewriter.getContext(); + auto typeConverter = getTypeConverter(); + Value 
message = + LLVM::addStringToModule(loc, rewriter, "assertMessage_", + makeNullTerminatedString(adaptor.getMessage())); + Value file = + LLVM::addStringToModule(loc, rewriter, "assertFile_", + makeNullTerminatedString(adaptor.getFile())); + Value func = + LLVM::addStringToModule(loc, rewriter, "assertFunc_", + makeNullTerminatedString(adaptor.getFunc())); + SmallVector args{getPid(op, 0), + getPid(op, 1), + getPid(op, 2), + op.getCondition(), + message, + file, + i32_val(adaptor.getLine()), + func}; + call(getAssertFuncDecl(rewriter), args); + rewriter.eraseOp(op); + return success(); + } + + static LLVM::LLVMFuncOp + getAssertFuncDecl(ConversionPatternRewriter &rewriter) { + auto moduleOp = + rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = "triton_assert"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType{i32_ty, i32_ty, i32_ty, i1_ty, + ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx)}; + + auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); + } +}; + using BarrierOp = mlir::gpu::BarrierOp; // This is part of the DebugOps pass because gpu::barrier is generated by @@ -247,8 +307,8 @@ struct DebugOpsToLLVM RewritePatternSet patterns(context); patterns.add(typeConverter); + patterns.add(typeConverter); patterns.add(typeConverter); - // patterns.add(typeConverter); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) { return signalPassFailure(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp index cf6e6704bc28..59806592d5d8 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -32,7 +32,10 @@ class DebugOpsConversionTarget : public ConversionTarget { addLegalDialect(); addLegalDialect(); + addLegalOp(); + addIllegalOp(); + addIllegalOp(); } }; @@ -65,6 +68,27 @@ struct PrintOpConversion : public OpConversionPattern { } }; +struct AssertOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value acc = rewriter.create(loc, i1_ty, + rewriter.getOneAttr(i1_ty)); + Value condition = rewriter.getRemappedValue(op.getCondition()); + SmallVector dimsToReduce( + cast(condition.getType()).getRank(), true); + condition = rewriter.create( + loc, condition, acc, dimsToReduce, vector::CombiningKind::AND); + rewriter.replaceOpWithNewOp( + op, condition, op.getMessage(), op.getFile(), op.getFunc(), + op.getLine()); + return success(); + } +}; + struct ConvertDebugOps : public triton::impl::ConvertDebugOpsBase { using ConvertDebugOpsBase::ConvertDebugOpsBase; @@ -79,6 +103,7 @@ struct ConvertDebugOps DebugOpsConversionTarget convTarget(*context, typeConverter); RewritePatternSet patterns(context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 
3d232ddb2530..888eeb028a63 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -148,9 +149,14 @@ void printFormattedElement(std::stringstream &ss, void *vec, size_t index, extern "C" { -EXPORT void triton_assert(bool cond, char *c) { - if (!cond) - fprintf(stderr, "%s\n", c); +EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond, + const char *message, const char *file, int32_t line, + const char *function) { + if (cond) + return; + fprintf(stderr, "%s:%u: %s: block: [%u, %u, %u] Assertion `%s` failed.\n", + file, line, function, pid0, pid1, pid2, message); + abort(); } // Print the pid prefix like the GPU ad interpreter. And vectors are printed From 75f512e455fcb40f16f0da34f2343078be23deb7 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 29 Aug 2024 21:53:42 -0400 Subject: [PATCH 101/165] Implement isnan, isinf, signbit (#127) Since these aren't natively available in sleef / libmvec, let's implement them natively in Triton. They were implemented via a @builtin wrapping a @jit function because operations like `get_int_dtype` don't work under @jit. --- python/test/unit/cpu/test_libdevice.py | 20 +++++- python/triton/language/extra/cpu/__init__.py | 3 - python/triton/language/extra/cpu/libdevice.py | 69 +++++++++++++++++++ 3 files changed, 87 insertions(+), 5 deletions(-) diff --git a/python/test/unit/cpu/test_libdevice.py b/python/test/unit/cpu/test_libdevice.py index 5a37ec9af21d..07e0d5ff1ccf 100644 --- a/python/test/unit/cpu/test_libdevice.py +++ b/python/test/unit/cpu/test_libdevice.py @@ -6,6 +6,8 @@ import triton.language as tl from triton.language.extra import libdevice +torch.manual_seed(0) + def is_interpreter(): return os.environ.get('TRITON_INTERPRET', '0') == '1' @@ -22,7 +24,7 @@ def is_cpu(): @pytest.mark.parametrize("dtype_str", float_dtypes) @pytest.mark.parametrize("math_fn", [ "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "expm1", "floor", - "log", "log1p", "log2", "log10", "rsqrt", "sin", "sinh", "sqrt", "tan", "tanh" + "isnan", "isinf", "log", "log1p", "log2", "log10", "rsqrt", "signbit", "sin", "sinh", "sqrt", "tan", "tanh" ]) @pytest.mark.parametrize("size", [1, 4, 16, 64]) def test_libdevice(dtype_str, math_fn, size, device): @@ -37,9 +39,23 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): tl.store(dst + idxs, y) src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + if math_fn == "acosh": src = src.abs() + 1 - res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) + if math_fn == "isnan" or math_fn == "isinf": + indices = torch.randint(low=0, high=size, size=(size // 2, ), device=device) + for i in indices: + if math_fn == "isnan": + src[i] = float("nan") + else: + src[i] = float(("+" if i % 2 else "-") + "inf") + + if math_fn in ["isnan", "isinf", "signbit"]: + out_dtype = torch.bool + else: + out_dtype = getattr(torch, dtype_str) + + res = torch.empty(src.shape, dtype=out_dtype, device=device) kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size) if math_fn == "cbrt": ref = src.pow(1 / 3) diff --git a/python/triton/language/extra/cpu/__init__.py b/python/triton/language/extra/cpu/__init__.py index 229b57d87d65..e69de29bb2d1 100644 --- a/python/triton/language/extra/cpu/__init__.py +++ b/python/triton/language/extra/cpu/__init__.py @@ -1,3 +0,0 @@ -from . 
import libdevice - -__all__ = ["libdevice"] diff --git a/python/triton/language/extra/cpu/libdevice.py b/python/triton/language/extra/cpu/libdevice.py index 9dbb4d682d42..e442d0234d0f 100644 --- a/python/triton/language/extra/cpu/libdevice.py +++ b/python/triton/language/extra/cpu/libdevice.py @@ -1,4 +1,7 @@ +import triton.language as tl from triton.language import core +from triton.language.core import builtin +from triton import jit @core.extern @@ -119,3 +122,69 @@ def tan(arg0, _builder=None): @core.extern def tanh(arg0, _builder=None): return core.tensor(_builder.create_tanh(arg0.handle), arg0.type) + + +@jit +def _const(v, dtype): + """ + Create a tensor with a single value of type :dtype. + """ + return tl.full((1, ), v, dtype) + + +@jit +def _is_special_float(arg0, uint_dtype, kind: tl.constexpr): + # By default, Triton assumes constexprs are int32. Thus, when we do operations with constants, + # we end up auto-promoting smaller integer types to int32, which is undesirable. Thus we + # explicitly cast them to our desired type here. + one = _const(1, uint_dtype) + zero = _const(0, uint_dtype) + + bitwidth: tl.constexpr = arg0.dtype.primitive_bitwidth + exponent_width: tl.constexpr = bitwidth - 1 - arg0.dtype.fp_mantissa_width + mantissa_width: tl.constexpr = arg0.dtype.fp_mantissa_width + + uintval = arg0.to(uint_dtype, bitcast=True) + exponent = uintval << one >> _const(mantissa_width, uint_dtype) + one + exp_is_all_ones = exponent == (one << _const(exponent_width, uint_dtype)) - one + shifted_mantissa = uintval << _const(exponent_width, uint_dtype) + one + + if kind == "nan": + return exp_is_all_ones & (shifted_mantissa != zero) + elif kind == "inf": + return exp_is_all_ones & (shifted_mantissa == zero) + else: + raise ValueError(f"Unexpected kind {kind}") + + +@builtin +def isnan(arg0, _builder=None, _generator=None): + if not arg0.dtype.is_floating(): + raise ValueError("isnan expects a floating point type") + bitwidth = arg0.dtype.primitive_bitwidth + uint_dtype = tl.core.get_int_dtype(bitwidth, signed=False) + return _generator.call_JitFunction(_is_special_float, (arg0, uint_dtype, "nan"), kwargs={}) + + +@builtin +def isinf(arg0, _builder=None, _generator=None): + if not arg0.dtype.is_floating(): + raise ValueError("isinf expects a floating point type") + bitwidth = arg0.dtype.primitive_bitwidth + uint_dtype = tl.core.get_int_dtype(bitwidth, signed=False) + return _generator.call_JitFunction(_is_special_float, (arg0, uint_dtype, "inf"), kwargs={}) + + +@jit +def _signbit(arg0, uint_dtype: tl.constexpr): + bitwidth: tl.constexpr = arg0.dtype.primitive_bitwidth + return arg0.to(uint_dtype, bitcast=True) >> (bitwidth - 1) + + +@builtin +def signbit(arg0, _builder=None, _generator=None): + if not arg0.dtype.is_floating(): + raise ValueError("signbit expects a floating point type") + bitwidth = arg0.dtype.primitive_bitwidth + uint_dtype = tl.core.get_int_dtype(bitwidth, signed=False) + return _generator.call_JitFunction(_signbit, (arg0, uint_dtype), kwargs={}) From 58c7d5187269f687ca32c8d8f55ca7c2eb25c6a5 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Sat, 31 Aug 2024 09:33:16 -0400 Subject: [PATCH 102/165] Vendor sleef as a submodule (#130) --- .github/workflows/build-test.yml | 2 ++ .gitignore | 1 + .gitmodules | 3 +++ third_party/cpu/CMakeLists.txt | 13 ++++++++++++- third_party/cpu/backend/compiler.py | 9 +-------- third_party/sleef | 1 + 6 files changed, 20 insertions(+), 9 deletions(-) create mode 100644 .gitmodules create mode 160000 third_party/sleef diff --git 
a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index f675e052acf6..0c8edb01a299 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -56,6 +56,8 @@ jobs: steps: - name: Checkout repository uses: actions/checkout@v4 + with: + submodules: recursive - name: Install Python ${{ matrix.python }} uses: actions/setup-python@v5 diff --git a/.gitignore b/.gitignore index ec8a00351867..06de43f2a69f 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ python/triton*.egg-info/ python/triton/_C/*.pyd python/triton/_C/*.so +python/triton/_C/*.so.* python/triton/_C/*.dylib python/triton/_C/*.pdb python/triton/_C/*.exe diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 000000000000..b2b6bf04a546 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "sleef"] + path = third_party/sleef + url = https://github.com/shibatch/sleef diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index c0bbcbfca1f6..fd55642022e4 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -5,5 +5,16 @@ add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm) - add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) endif() + +add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) + +# Build and link sleef +set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE) +set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE) +set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE) +set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE) +set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will not be built." 
FORCE) +add_subdirectory("${CMAKE_SOURCE_DIR}/third_party/sleef" sleef) +# Override sleef's output directory with our own +set_target_properties(sleef PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index f9ec13ec5042..990f874cc262 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -215,14 +215,7 @@ def make_so(src, metadata, options): asm_path = os.path.join(tmpdir, "kernel.s") Path(asm_path).write_text(src) lib_dirs = cpu_driver.library_dirs - libs = ["gcc", "m", "TritonCPURuntime"] - # TRITON_CPU_USE_SLEEF=1 - use system libsleef - # TRITON_CPU_USE_SLEEF=path - use libsleef from the specified path - use_sleef = os.environ.get("TRITON_CPU_USE_SLEEF", "0") - if use_sleef != "0": - if os.path.isdir(use_sleef): - lib_dirs.append(use_sleef) - libs.append("sleef") + libs = ["gcc", "m", "TritonCPURuntime", "sleef"] so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs) with open(so, "rb") as f: return f.read() diff --git a/third_party/sleef b/third_party/sleef new file mode 160000 index 000000000000..93f04d869471 --- /dev/null +++ b/third_party/sleef @@ -0,0 +1 @@ +Subproject commit 93f04d869471ce4d007abaebb8c6a7bc62749f61 From ebada24eba5f5755ef74069423a5ed92f496721e Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Sun, 1 Sep 2024 08:47:17 -0400 Subject: [PATCH 103/165] Add test_debug_dump.py to CI (#131) --- .github/workflows/build-test.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 0c8edb01a299..f2c446076114 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -99,7 +99,8 @@ jobs: python/test/unit/runtime/test_cache.py \ python/test/unit/cpu/test_libdevice.py \ python/test/unit/cpu/test_libmvec.py \ - python/test/unit/cpu/test_opt.py + python/test/unit/cpu/test_opt.py \ + python/test/unit/test_debug_dump.py - name: Run lit tests run: | From 538ed7f2baf15f31c685fc4a06d2ae0e15f4d1d1 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Wed, 4 Sep 2024 15:55:28 -0400 Subject: [PATCH 104/165] Refactor MathToLibmvec pass (#135) Since the pass can generate libsleef calls too, having "libmvec" in the pass name is a bit of a misnomer, so I've renamed it accordingly. I've also switched from using a bool to using an enum to select between vector math libraries. This makes for cleaner code and makes it easy to extend to support additional libraries (e.g. MKL). Finally (and most significantly), I've changed how library function names are generated. Instead of having all the logic in the MathToLibmvecPass, I'm giving each library its own naming functor. This makes it easier to extend to accomodate each library's quirks, as following PRs will demonstrate. 
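As background (not part of this patch), the two mangling schemes implemented by the naming functors below can be sketched in Python; this only mirrors MvecNameGenerator and SleefNameGenerator for illustration:

    # Rough Python rendering of the naming functors (illustrative only).
    def mvec_name(base, bitwidth, numel, num_operands):
        if bitwidth not in (32, 64):
            return ""
        isa = {128: "b", 256: "d", 512: "e"}.get(numel * bitwidth, "")
        if not isa:
            return ""
        suffix = "f" if bitwidth == 32 else ""
        return "_ZGV" + isa + "N" + str(numel) + "v" * num_operands + "_" + base + suffix

    def sleef_name(base, bitwidth, numel, ulp=10):
        if bitwidth not in (32, 64) or numel * bitwidth < 128:
            return ""
        return "Sleef_" + base + ("f" if bitwidth == 32 else "d") + str(numel) + "_u" + str(ulp)

    # A 16 x f32 sin maps to "_ZGVeN16v_sinf" (libmvec) or "Sleef_sinf16_u10" (sleef).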
--- third_party/cpu/backend/compiler.py | 11 +- .../cpu/include/TritonCPUToLLVM/Passes.h | 8 +- .../cpu/include/TritonCPUToLLVM/Passes.td | 10 +- .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 2 +- .../{MathToLibmvec.cpp => MathToVecLib.cpp} | 196 ++++++++++-------- third_party/cpu/triton_cpu.cc | 11 +- 6 files changed, 131 insertions(+), 107 deletions(-) rename third_party/cpu/lib/TritonCPUToLLVM/{MathToLibmvec.cpp => MathToVecLib.cpp} (62%) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 990f874cc262..6d8cd42912d4 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -165,10 +165,13 @@ def make_llir(self, src, metadata, options): cpu.passes.ttcpuir.add_memory_op_to_llvmir(pm) cpu.passes.ttcpuir.add_atomic_ops_to_llvmir(pm) cpu.passes.ttcpuir.add_debug_ops_to_llvmir(pm) - use_sleef = os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0" - use_vec_math = os.environ.get("TRITON_CPU_USE_LIBMVEC", "1") != "0" - if (use_sleef or use_vec_math) and self.cpu_arch == "x86_64" and "avx512f" in self.cpu_features: - cpu.passes.ttcpuir.add_math_to_libmvec(pm, use_sleef) + vec_lib = None + if os.environ.get("TRITON_CPU_USE_LIBMVEC", "1") != "0": + vec_lib = cpu.passes.ttcpuir.VecLib.libmvec + if os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0": + vec_lib = cpu.passes.ttcpuir.VecLib.libsleef + if vec_lib is not None and self.cpu_arch == "x86_64" and "avx512f" in self.cpu_features: + cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib) passes.convert.add_math_to_llvmir(pm) cpu.passes.ttcpuir.add_math_to_libm(pm) cpu.passes.ttcpuir.add_vector_to_llvmir(pm, options.enable_fast_math) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index fb366d6f82bc..6e9892d00206 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -17,6 +17,11 @@ template class OperationPass; namespace triton { namespace cpu { +enum class VecLib { + Mvec, + Sleef, +}; + #define GEN_PASS_DECL #include "cpu/include/TritonCPUToLLVM/Passes.h.inc" @@ -26,9 +31,8 @@ std::unique_ptr> createGetProgramIdOpToLLVMPass(); std::unique_ptr> createLowerMultiReductionPass(); std::unique_ptr> createAtomicOpsToLLVMPass(); std::unique_ptr> createDebugOpsToLLVMPass(); -std::unique_ptr> createMathToLibmvecPass(); std::unique_ptr> -createMathToLibmvecPass(bool use_sleef); +createMathToVecLibPass(VecLib lib = VecLib::Sleef); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUToLLVM/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index 6b3e0a8bd9d0..3ee08d9968b2 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -66,16 +66,16 @@ def DebugOpsToLLVM : Pass<"triton-cpu-debug-ops-to-llvm", "mlir::ModuleOp"> { "mlir::triton::TritonDialect"]; } -def MathToLibmvec : Pass<"triton-cpu-math-to-libmvec", "mlir::ModuleOp"> { +def MathToVecLib : Pass<"triton-cpu-math-to-vec-lib", "mlir::ModuleOp"> { let summary = "Convert vector math operations to vector libm or sleef calls."; let description = [{ }]; - let constructor = "mlir::triton::cpu::createMathToLibmvecPass()"; + let constructor = "mlir::triton::cpu::createMathToVecLibPass()"; let options = [ - Option<"use_sleef", "use-sleef", - "bool", /*default*/"false", - "Use sleef library for vector math instead of libmvec.">, + Option<"lib", "lib", + 
"mlir::triton::cpu::VecLib", /*default*/"mlir::triton::cpu::VecLib::Sleef", + "Library to use for vector math (libsleef or libmvec).">, ]; let dependentDialects = ["mlir::vector::VectorDialect", diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index b4c6372132be..5448d81937f4 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -4,7 +4,7 @@ add_triton_library(TritonCPUToLLVM FuncOpToLLVM.cpp GetProgramIdOpToLLVM.cpp LowerMultiReduction.cpp - MathToLibmvec.cpp + MathToVecLib.cpp MemoryOpToLLVM.cpp TypeConverter.cpp Utility.cpp diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp similarity index 62% rename from third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp rename to third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index 67c666904abd..782026146979 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToLibmvec.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -13,7 +13,7 @@ namespace mlir { namespace triton { namespace cpu { -#define GEN_PASS_DEF_MATHTOLIBMVEC +#define GEN_PASS_DEF_MATHTOVECLIB #include "cpu/include/TritonCPUToLLVM/Passes.h.inc" } // namespace cpu } // namespace triton @@ -52,7 +52,7 @@ template struct VecOpToFp32 : public OpRewritePattern { } }; -// Decompose vector operation to singe-dimensional vector operations +// Decompose vector operation to single-dimensional vector operations // with a native AVX512 vector size. template struct DecomposeToNativeVecs : public OpRewritePattern { @@ -134,32 +134,73 @@ struct DecomposeToNativeVecs : public OpRewritePattern { } }; -template -struct VecOpToLibmvecCall : public OpRewritePattern { +using GetVecFnNameFn = std::function; + +class MvecNameGenerator { public: - using OpRewritePattern::OpRewritePattern; + explicit MvecNameGenerator(StringRef baseName) : baseName(baseName) {} - VecOpToLibmvecCall(MLIRContext *context, StringRef fp32FnBaseName, - StringRef fp64FnBaseName, bool use_sleef) - : OpRewritePattern(context) { - this->fp32FnBaseName = fp32FnBaseName; - this->fp64FnBaseName = fp64FnBaseName; - this->use_sleef = use_sleef; + std::string operator()(unsigned bitwidth, unsigned numel, + ValueRange operands) const { + if (bitwidth != 32 && bitwidth != 64) + return ""; + unsigned vecSize = numel * bitwidth; + std::string isaPrefix; + if (vecSize == 128) { + isaPrefix = "b"; + } else if (vecSize == 256) { + isaPrefix = "d"; + } else if (vecSize == 512) { + isaPrefix = "e"; + } else { + return ""; + } + std::string fnName = "_ZGV" + isaPrefix + "N" + std::to_string(numel); + for (auto operand : operands) + fnName += "v"; + return fnName + "_" + baseName + (bitwidth == 32 ? "f" : ""); } +private: + std::string baseName; +}; + +class SleefNameGenerator { +public: + SleefNameGenerator(StringRef baseName, unsigned ulp = 10) + : baseName(baseName), ulp(std::to_string(ulp)) {} + + std::string operator()(unsigned bitwidth, unsigned numel, + ValueRange /*operands*/) const { + if (bitwidth != 32 && bitwidth != 64) + return ""; + unsigned vecSize = numel * bitwidth; + if (vecSize < 128) + return ""; + return "Sleef_" + baseName + (bitwidth == 32 ? 
"f" : "d") + + std::to_string(numel) + "_u" + ulp; + } + +private: + std::string baseName; + std::string ulp; +}; + +template struct VecOpToVecLib : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + VecOpToVecLib(MLIRContext *context, GetVecFnNameFn getVecFnName) + : OpRewritePattern(context), getVecFnName(getVecFnName) {} + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { VectorType vecTy = dyn_cast(op.getType()); if (!vecTy || vecTy.getRank() > 1) return failure(); - Type elemTy = vecTy.getElementType(); - if (!elemTy.isF32() && !elemTy.isF64()) - return failure(); - - auto fnName = use_sleef - ? getSleefName(elemTy.isF32(), vecTy.getNumElements()) - : getLibmvecName(elemTy.isF32(), vecTy.getNumElements(), - op->getOperands()); + auto fnName = getVecFnName(vecTy.getElementTypeBitWidth(), + vecTy.getNumElements(), op->getOperands()); if (fnName.empty()) return failure(); @@ -184,90 +225,68 @@ struct VecOpToLibmvecCall : public OpRewritePattern { return success(); } - std::string getLibmvecName(bool isFp32, int64_t numElems, - ValueRange ops) const { - auto baseName = isFp32 ? fp32FnBaseName : fp64FnBaseName; - int64_t vecSize = numElems * (isFp32 ? 32 : 64); - std::string isaPrefix; - if (vecSize == 128) { - isaPrefix = "b"; - } else if (vecSize == 256) { - isaPrefix = "d"; - } else if (vecSize == 512) { - isaPrefix = "e"; - } else { - return ""; - } - std::string fnName = "_ZGV" + isaPrefix + "N" + std::to_string(numElems); - for (auto operand : ops) - fnName += "v"; - fnName += "_" + baseName; - return fnName; - } - - std::string getSleefName(bool isFp32, int64_t numElems) const { - int64_t vecSize = numElems * (isFp32 ? 32 : 64); - if (vecSize < 128) - return ""; - auto baseName = isFp32 ? 
fp32FnBaseName : (fp64FnBaseName + "d"); - return "Sleef_" + baseName + std::to_string(numElems) + "_u10"; - } - private: - std::string fp32FnBaseName; - std::string fp64FnBaseName; - bool use_sleef; + GetVecFnNameFn getVecFnName; }; template -void populatePatternsForOp(RewritePatternSet &patterns, StringRef fp32FnName, - StringRef fp64FnName, bool use_sleef) { +void populatePatternsForOp(RewritePatternSet &patterns, + GetVecFnNameFn getVecFnName) { patterns.add>(patterns.getContext()); patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext(), fp32FnName, - fp64FnName, use_sleef); + patterns.add>(patterns.getContext(), getVecFnName); } -struct MathToLibmvecPass - : public mlir::triton::cpu::impl::MathToLibmvecBase { - MathToLibmvecPass() = default; +struct MathToVecLibPass + : public mlir::triton::cpu::impl::MathToVecLibBase { + MathToVecLibPass() = default; - MathToLibmvecPass(bool use_sleef) { this->use_sleef = use_sleef; } + explicit MathToVecLibPass(VecLib lib) { this->lib = lib; } void runOnOperation() override { Operation *op = getOperation(); MLIRContext *context = op->getContext(); RewritePatternSet patterns(context); - populatePatternsForOp(patterns, "acosf", "acos", use_sleef); - populatePatternsForOp(patterns, "acoshf", "acosh", - use_sleef); - populatePatternsForOp(patterns, "asinf", "asin", use_sleef); - populatePatternsForOp(patterns, "asinhf", "asinh", - use_sleef); - populatePatternsForOp(patterns, "atanf", "atan", use_sleef); - populatePatternsForOp(patterns, "atanhf", "atanh", - use_sleef); - populatePatternsForOp(patterns, "cbrtf", "cbrt", use_sleef); - populatePatternsForOp(patterns, "cosf", "cos", use_sleef); - populatePatternsForOp(patterns, "coshf", "cosh", use_sleef); - populatePatternsForOp(patterns, "erff", "erf", use_sleef); - populatePatternsForOp(patterns, "expf", "exp", use_sleef); - populatePatternsForOp(patterns, "exp2f", "exp2", use_sleef); - populatePatternsForOp(patterns, "logf", "log", use_sleef); - populatePatternsForOp(patterns, "log2f", "log2", use_sleef); - populatePatternsForOp(patterns, "log10f", "log10", - use_sleef); - populatePatternsForOp(patterns, "log1pf", "log1p", - use_sleef); - populatePatternsForOp(patterns, "sinf", "sin", use_sleef); - populatePatternsForOp(patterns, "sinhf", "sinh", use_sleef); - populatePatternsForOp(patterns, "tanf", "tan", use_sleef); - populatePatternsForOp(patterns, "tanhf", "tanh", use_sleef); + + switch (lib) { + case VecLib::Mvec: { + populateCommonPatterns(patterns); + break; + } + case VecLib::Sleef: { + populateCommonPatterns(patterns); + break; + } + } if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } + + template + void populateCommonPatterns(RewritePatternSet &patterns) const { + populatePatternsForOp(patterns, VecFnNameGenerator("acos")); + populatePatternsForOp(patterns, VecFnNameGenerator("acosh")); + populatePatternsForOp(patterns, VecFnNameGenerator("asin")); + populatePatternsForOp(patterns, VecFnNameGenerator("asinh")); + populatePatternsForOp(patterns, VecFnNameGenerator("atan")); + populatePatternsForOp(patterns, VecFnNameGenerator("atanh")); + populatePatternsForOp(patterns, VecFnNameGenerator("cbrt")); + populatePatternsForOp(patterns, VecFnNameGenerator("cos")); + populatePatternsForOp(patterns, VecFnNameGenerator("cosh")); + populatePatternsForOp(patterns, VecFnNameGenerator("erf")); + populatePatternsForOp(patterns, VecFnNameGenerator("exp")); + populatePatternsForOp(patterns, VecFnNameGenerator("exp2")); + 
populatePatternsForOp(patterns, VecFnNameGenerator("log")); + populatePatternsForOp(patterns, VecFnNameGenerator("log2")); + populatePatternsForOp(patterns, VecFnNameGenerator("log10")); + populatePatternsForOp(patterns, VecFnNameGenerator("log1p")); + populatePatternsForOp(patterns, VecFnNameGenerator("sin")); + populatePatternsForOp(patterns, VecFnNameGenerator("sinh")); + populatePatternsForOp(patterns, VecFnNameGenerator("tan")); + populatePatternsForOp(patterns, VecFnNameGenerator("tanh")); + } }; } // anonymous namespace @@ -276,13 +295,8 @@ namespace mlir { namespace triton { namespace cpu { -std::unique_ptr> createMathToLibmvecPass() { - return std::make_unique(); -} - -std::unique_ptr> -createMathToLibmvecPass(bool use_sleef) { - return std::make_unique(use_sleef); +std::unique_ptr> createMathToVecLibPass(VecLib lib) { + return std::make_unique(lib); } } // namespace cpu diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 3680bb20ace6..aa365779f407 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -18,12 +18,15 @@ #include #include -#include - namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { using namespace mlir::triton; + + py::enum_(m, "VecLib") + .value("libsleef", cpu::VecLib::Sleef) + .value("libmvec", cpu::VecLib::Mvec); + m.def("add_convert_memory_ops", [](mlir::PassManager &pm, bool useScalarLoops) { pm.addPass(mlir::triton::cpu::createConvertMemoryOps(useScalarLoops)); @@ -122,8 +125,8 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); }); - m.def("add_math_to_libmvec", [](mlir::PassManager &pm, bool use_sleef) { - pm.addPass(mlir::triton::cpu::createMathToLibmvecPass(use_sleef)); + m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib) { + pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib)); }); m.def("add_math_to_libm", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertMathToLibmPass()); From 24d4bafe6ec59a02d9ccc2142dfe99dcdc8bca5d Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Wed, 4 Sep 2024 13:26:04 -0700 Subject: [PATCH 105/165] [CPU] Add unit test for print with isSigned and several fixes (#132) --- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 8 +- python/test/unit/language/print_helper.py | 24 +-- python/test/unit/language/test_subprocess.py | 114 +++++++++++++- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 34 ++++- .../lib/TritonToTritonCPU/ConvertDebugOps.cpp | 8 +- third_party/cpu/runtime/cpu_runtime.cpp | 142 ++++++++++++------ 6 files changed, 267 insertions(+), 63 deletions(-) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index bbd832c44621..1246a1162508 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -86,8 +86,12 @@ def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { It only takes a single scalar or vector (tensor) element. }]; - let arguments = (ins StrAttr:$prefix, BoolAttr:$hex, - Variadic>:$val); + let arguments = (ins + StrAttr:$prefix, + BoolAttr:$hex, + Variadic>:$val, + DenseI32ArrayAttr:$isSigned + ); let assemblyFormat = [{ $prefix attr-dict (`:` $val^ `:` type($val))? 
diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index dde1409c4519..07cc1cc7223c 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -35,10 +35,12 @@ def kernel_print(X, Y, BLOCK: tl.constexpr): @triton.jit -def kernel_device_print_scalar(SCALAR): +def kernel_device_print_scalars(SCALAR, INT, FLOAT): x = tl.load(SCALAR) # Triton should add a space after this prefix. print("x:", x) + print("int:", INT) + print("float:", FLOAT) @triton.jit @@ -99,19 +101,21 @@ def kernel_print_2d_tensor(X, Y, BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.co def test_print(func: str, data_type: str, device: str): - N = 128 # This value should match with test_print in test_subprocess.py. + N = 128 # This value should match with test_print in test_subprocess.py. + SCALAR = 42 + # TODO(antiagainst): Currently the warp count is chosen to make sure we don't have multiple # threads printing duplicated messages due to broadcasting. Improve print op lowering logic # to filter out duplicated data range. - num_warps = N // get_current_target_warp_size() + num_warps = N // (get_current_target_warp_size() if device != "cpu" else 1) x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) y = torch.zeros((N, ), dtype=x.dtype, device=device) if func == "device_print": kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) - elif func == "device_print_scalar": - scalar = torch.tensor(42, dtype=x.dtype, device=device) - kernel_device_print_scalar[(1, )](scalar, num_warps=num_warps) + elif func == "device_print_scalars": + scalar = torch.tensor(SCALAR, dtype=x.dtype, device=device) + kernel_device_print_scalars[(1, )](scalar, SCALAR, 3.14, num_warps=num_warps) elif func == "device_print_negative": x = -x kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) @@ -143,16 +147,16 @@ def test_print(func: str, data_type: str, device: str): kernel_print_2d_tensor[(1, )](x_2d_tensor, y, num_warps=num_warps, BLOCK_SIZE_X=BLOCK_SIZE_X, BLOCK_SIZE_Y=BLOCK_SIZE_Y) else: - assert f"Unknown kernel: {func}" - + assert False, f"Unknown kernel: {func}" if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ func != "print_multiple_args" and func != "device_print_multiple_args" and \ - func != "device_print_pointer" and func != "device_print_scalar" and func != "device_print_2d_tensor": + func != "device_print_pointer" and func != "device_print_scalars" and func != "device_print_2d_tensor": assert_close(y, x) # Wait until driver complete all the jobs for the device_print, especially test_subprocess # require this which captures stdout when child exits. 
- getattr(torch, device).synchronize() + if device != "cpu": + torch.cuda.synchronize() if __name__ == "__main__": diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index f1e415bbb426..bd0c80fca80e 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -8,6 +8,8 @@ import pytest +import triton + dir_path = os.path.dirname(os.path.realpath(__file__)) print_path = os.path.join(dir_path, "print_helper.py") torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] @@ -17,12 +19,18 @@ def is_interpreter(): return os.environ.get('TRITON_INTERPRET', '0') == '1' +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + # TODO: Print with multiple operands +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("func_type, data_type", [(fn, data_type) - for fn in ["device_print", "device_print_scalar"] + for fn in ["device_print", "device_print_scalars"] for data_type in torch_types] + [ ("print", "int32"), ("static_print", "int32"), @@ -40,6 +48,9 @@ def is_interpreter(): ("device_print_2d_tensor", "int32"), ]) def test_print(func_type: str, data_type: str, device: str): + if is_cpu() and (data_type == "float16" or func_type in ["device_print_pointer", "device_print_large"]): + pytest.skip("test_print for float16/pointer/large are not yet supported on CPU.") + proc = subprocess.run( [sys.executable, print_path, "test_print", func_type, data_type, device], capture_output=True, @@ -59,6 +70,11 @@ def test_print(func_type: str, data_type: str, device: str): # Constant for testing the printing of scalar values SCALAR_VAL = 42 + # TODO: Consider cases for signedness, overflow, and multiple pids (non-determinism). + if is_cpu(): + _check_cpu_print(proc.stdout.decode("UTF-8"), func_type, data_type, N, SCALAR_VAL) + return + # Format is # pid (, , ) idx (, , ...) (operand ) expected_lines = Counter() @@ -69,11 +85,15 @@ def test_print(func_type: str, data_type: str, device: str): if data_type.startswith("float"): line += ".000000" expected_lines[line] = 1 - elif func_type == "device_print_scalar": + elif func_type == "device_print_scalars": line = f"pid (0, 0, 0) idx () x: {SCALAR_VAL}" if data_type.startswith("float"): line += ".000000" expected_lines[line] = N + line = f"pid (0, 0, 0) idx () int: {SCALAR_VAL}" + expected_lines[line] = N + line = "pid (0, 0, 0) idx () float: 3.140000" + expected_lines[line] = N elif func_type == "device_print_negative": for i in range(N): line = f"pid (0, 0, 0) idx ({i:3}) x: {-i}" @@ -112,8 +132,11 @@ def test_print(func_type: str, data_type: str, device: str): for y in range(y_dim): expected_lines[f"pid (0, 0, 0) idx ({x}, {y:2}): {(x * y_dim + y)}"] = 1 + cpu_gpu_msg = "Both CPU and GPU backends are available. Using the GPU backend." actual_lines = Counter() for line in outs: + if line == cpu_gpu_msg: + continue # Trim the exact pointer address in the output--they can change per run. 
line = (line.split(':')[0] + ": 0x") if func_type == "device_print_pointer" else line actual_lines[line] += 1 @@ -125,3 +148,90 @@ def test_print(func_type: str, data_type: str, device: str): continue print(f'Expected line "{line}" {expected_lines[line]} time(s), but saw {actual_lines[line]} time(s)') assert all(delta == 0 for delta in diff.values()) + + +def _check_cpu_print(actual, func_type, data_type, N, SCALAR_VAL): + # An example of a tensor printing is like: + # (0, 0, 0) x: [ 0, 1, 2, 3, 4, 5, 6, 7, + # 8, 9, 10, 11, 12, 13, 14, 15, + # ... + # 120, 121, 122, 123, 124, 125, 126, 127] + PID_PREFIX = "(0, 0, 0)" + NEWLINE_WITH_PADDING = "\n" + " " * (len(PID_PREFIX + " x: [")) + if func_type in ("print", "device_print", "device_print_uint"): + expected = PID_PREFIX + " x: [" + for i in range(N): + offset = (1 << 31) if data_type == "uint32" else 0 + expected += f"{i + offset:3}" + if data_type.startswith("float"): + expected += ".0000" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += NEWLINE_WITH_PADDING + else: + expected += " " + expected += "]" + elif func_type == "device_print_scalars": + expected = f"{PID_PREFIX} x: {SCALAR_VAL}" + if data_type.startswith("float"): + expected += ".000000" + expected += f"\n{PID_PREFIX} int: {SCALAR_VAL}" + expected += f"\n{PID_PREFIX} float: 3.140000" + elif func_type == "device_print_negative": + expected = PID_PREFIX + " x: [" + for i in range(N): + expected += f"{-i:4}" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += NEWLINE_WITH_PADDING + else: + expected += " " + expected += "]" + elif func_type == "device_print_hex": + expected = PID_PREFIX + " x: [" + for i in range(N): + if data_type.endswith("8"): + expected += f"0x{i:02x}" + elif data_type.endswith("16"): + expected += f"0x{i:04x}" + elif data_type.endswith("32"): + expected += f"0x{i:08x}" + elif data_type.endswith("64"): + expected += f"0x{i:016x}" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += NEWLINE_WITH_PADDING + else: + expected += " " + expected += "]" + elif func_type == "static_print": + expected = f" int32[constexpr[{N}]]" + elif func_type == "no_arg_print": + expected = f"{PID_PREFIX}: 0" + elif func_type == "print_no_arg": + expected = f"{PID_PREFIX} no arg" + elif func_type == "print_multiple_args" or func_type == "device_print_multiple_args": + expected = "" + for k in range(2): + expected += PID_PREFIX + ": [" + for i in range(N): + expected += f"{i:3}" if k == 0 else "1" + if i == N - 1: + continue + expected += "," + if i % 8 == 7: + expected += "\n" + " " * (len(PID_PREFIX + ": [")) + else: + expected += " " + expected += "]" + if k == 0: + expected += "\n" + + # Ignore the trailing new line. + assert actual[:-1] == expected diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index 9a39ed25b78d..ed26c2196a44 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -76,6 +76,26 @@ std::string getFormatSubstr(Value value, bool hex = false, return ""; } +// For printf, need to extend int32 or float64. 
+Value printfPromoteValue(RewriterBase &rewriter, Value value) { + auto *context = rewriter.getContext(); + auto type = value.getType(); + auto loc = UnknownLoc::get(context); + + bool isUnsigned = type.isUnsignedInteger(); + if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { + if (isUnsigned) { + return zext(ui32_ty, value); + } else { + return sext(i32_ty, value); + } + } else if (type.isBF16() || type.isF16() || type.isF32()) { + return fpext(f64_ty, value); + } + + return value; +} + LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter, bool printf) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); @@ -89,8 +109,8 @@ LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter, if (printf) argsType = {ptr_ty(ctx)}; else - argsType = {i32_ty, i32_ty, i32_ty, ptr_ty(ctx), - ptr_ty(ctx), i32_ty, i32_ty, i64_ty}; + argsType = {i32_ty, i32_ty, i32_ty, ptr_ty(ctx), ptr_ty(ctx), + i32_ty, i32_ty, i32_ty, i64_ty, i32_ty}; auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ printf); @@ -131,12 +151,13 @@ void llPrintf(StringRef prefix, std::array pid, for (auto elem : pid) allArgs.push_back(elem); if (arg.has_value()) - allArgs.push_back(arg.value()); + allArgs.push_back(printfPromoteValue(rewriter, arg.value())); call(getPrintFuncDecl(rewriter, true), allArgs); } void llVectorPrint(std::array pid, StringRef prefix, Value ptr, - bool isInteger, uint32_t bitWidth, int64_t numElem, + bool isInteger, bool isSigned, uint32_t bitWidth, + int64_t numElem, bool hex, ConversionPatternRewriter &rewriter) { assert(!prefix.empty()); auto loc = UnknownLoc::get(rewriter.getContext()); @@ -150,8 +171,10 @@ void llVectorPrint(std::array pid, StringRef prefix, Value ptr, allArgs.push_back(prefixValue); allArgs.push_back(ptr); allArgs.push_back(i32_val(isInteger)); + allArgs.push_back(i32_val(isSigned)); allArgs.push_back(i32_val(bitWidth)); allArgs.push_back(i64_val(numElem)); + allArgs.push_back(i32_val(hex)); call(getPrintFuncDecl(rewriter, false), allArgs); } @@ -203,8 +226,9 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { // booleans and separate bit width. llVectorPrint(pid, op.getPrefix(), ptr, vecShapedType.getElementType().isInteger(), + op.getIsSigned()[0], vecShapedType.getElementTypeBitWidth(), - vecShapedType.getNumElements(), rewriter); + vecShapedType.getNumElements(), op.getHex(), rewriter); } else { // TODO: support 2D+ vector printing. std::string msg{op.getPrefix()}; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp index 59806592d5d8..8a83156e4c52 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -51,15 +51,19 @@ struct PrintOpConversion : public OpConversionPattern { // (tt.print doesn't accept vector types, so we have this intermediate op.) if (op.getNumOperands() == 0) { rewriter.create(loc, op.getPrefix(), op.getHex(), - ValueRange{}); + ValueRange{}, + llvm::SmallVector{}); } else { // triton_cpu.print takes up to one vector or scalar operand. It prints // each value as a separate print call like the GPU and interpreter. + assert(op.getNumOperands() == op.getIsSigned().size()); for (size_t i = 0; i < op.getNumOperands(); i++) { Value opr = op.getOperands()[i]; + llvm::SmallVector isSigned = {op.getIsSigned()[i]}; // TODO: Consider using memrefs for general N-dimensional vectors. 
rewriter.create(loc, op.getPrefix(), op.getHex(), - rewriter.getRemappedValue(opr)); + rewriter.getRemappedValue(opr), + isSigned); } } diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 888eeb028a63..9f306ececb9d 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -24,10 +24,12 @@ const int ELEMS_PER_LINE = 8; struct FormatInfo { bool isInt; + bool isSigned; int bitWidth; int maxIntDigits; bool hasNegative; bool scientific; + bool isHex; }; template @@ -36,10 +38,10 @@ computeDigitInfoHelper(const void *array, size_t index) { T elem = static_cast(array)[index]; if (elem == 0) return {1, false}; - return {static_cast(std::log10(std::abs(elem))) + 1, elem < 0}; + return {static_cast(std::log10(elem >= 0 ? elem : -elem)) + 1, elem < 0}; } -std::pair computeDigitInfo(void *vec, int32_t isInt, +std::pair computeDigitInfo(void *vec, bool isInt, bool isSigned, int32_t bitWidth, size_t index) { if (isInt == 0) { @@ -50,28 +52,47 @@ std::pair computeDigitInfo(void *vec, int32_t isInt, else assert(false && "Unsupported bitWidth"); } else { - // TODO: Handle signed types? - if (bitWidth == 64) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 32) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 16) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 8) - return computeDigitInfoHelper(vec, index); - else - assert(false && "Unsupported bitWidth"); + if (isSigned) { + if (bitWidth == 64) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 32) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 16) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 8) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 1) + return computeDigitInfoHelper(vec, index); + } else { + if (bitWidth == 64) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 32) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 16) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 8) + return computeDigitInfoHelper(vec, index); + else if (bitWidth == 1) + return computeDigitInfoHelper(vec, index); + } + printf("bitWidth: %d\n", bitWidth); + assert(false && "Unsupported bitWidth"); } } -FormatInfo getFormatInfo(void *vec, bool isInt, int32_t bitWidth, - int64_t numElem) { +FormatInfo getFormatInfo(void *vec, bool isInt, bool isSigned, int32_t bitWidth, + int64_t numElem, bool isHex) { + if (isHex) { + assert(bitWidth >= 8 && bitWidth <= 64 && bitWidth % 8 == 0); + return {isInt, isSigned, bitWidth, bitWidth / 4, false, false, true}; + } // Compute the max/min widths for pretty printing. int maxIntDigits = 0; int minIntDigits = std::numeric_limits::max(); bool hasNegative = false; for (int64_t i = 0; i < numElem; ++i) { - auto [digits, negative] = computeDigitInfo(vec, isInt, bitWidth, i); + auto [digits, negative] = + computeDigitInfo(vec, isInt, isSigned, bitWidth, i); hasNegative |= negative; maxIntDigits = std::max(maxIntDigits, digits); minIntDigits = std::min(minIntDigits, digits); @@ -84,7 +105,8 @@ FormatInfo getFormatInfo(void *vec, bool isInt, int32_t bitWidth, scientific = maxIntDigits + 2 + (hasNegative ? 
1 : 0) > MAX_FLOAT_WIDTH; scientific |= maxIntDigits - minIntDigits > 3; } - return {isInt, bitWidth, maxIntDigits, hasNegative, scientific}; + return {isInt, isSigned, bitWidth, maxIntDigits, + hasNegative, scientific, false}; } template @@ -94,9 +116,9 @@ void printElementHelper(std::stringstream &ss, const void *array, } void printElement(std::stringstream &ss, const void *vec, size_t index, - bool isInt, int bitWidth) { - if (isInt == 0) { - switch (bitWidth) { + const FormatInfo &formatInfo) { + if (!formatInfo.isInt) { + switch (formatInfo.bitWidth) { case 32: printElementHelper(ss, vec, index); break; @@ -107,42 +129,76 @@ void printElement(std::stringstream &ss, const void *vec, size_t index, assert(false && "Unsupported bitWidth"); } } else { - switch (bitWidth) { - case 64: - printElementHelper(ss, vec, index); - break; - case 32: - printElementHelper(ss, vec, index); - break; - case 16: - printElementHelper(ss, vec, index); - break; - case 8: - // TODO: Seems like not working well. Need to fix it. - printElementHelper(ss, vec, index); - break; - default: - assert(false && "Unsupported bitWidth"); + if (formatInfo.isSigned) { + switch (formatInfo.bitWidth) { + case 64: + printElementHelper(ss, vec, index); + break; + case 32: + printElementHelper(ss, vec, index); + break; + case 16: + printElementHelper(ss, vec, index); + break; + case 8: + // int8_t is printed as char. + ss << static_cast(static_cast(vec)[index]); + break; + case 1: + printElementHelper(ss, vec, index); + break; + default: + assert(false && "Unsupported bitWidth"); + } + } else { + switch (formatInfo.bitWidth) { + case 64: + printElementHelper(ss, vec, index); + break; + case 32: + printElementHelper(ss, vec, index); + break; + case 16: + printElementHelper(ss, vec, index); + break; + case 8: + ss << static_cast(static_cast(vec)[index]); + break; + case 1: + printElementHelper(ss, vec, index); + break; + default: + assert(false && "Unsupported bitWidth"); + } } } } void printFormattedElement(std::stringstream &ss, void *vec, size_t index, const FormatInfo &formatInfo) { + // Right now, the GPU's hex float doesn't work correctly. C++ has std:: + // hexfloat, but let's consider only hex integers for now. + if (formatInfo.isHex && formatInfo.isInt) { + ss << "0x" << std::hex << std::setw(formatInfo.maxIntDigits) + << std::setfill('0'); + printElement(ss, vec, index, formatInfo); + return; + } + int padding = 0; - auto [digits, negative] = - computeDigitInfo(vec, formatInfo.isInt, formatInfo.bitWidth, index); + auto [digits, negative] = computeDigitInfo( + vec, formatInfo.isInt, formatInfo.isSigned, formatInfo.bitWidth, index); if (!negative && formatInfo.hasNegative) padding++; if (formatInfo.scientific) { ss << std::scientific << std::setw(MAX_FLOAT_WIDTH) << std::setprecision(FLOAT_PREC) << std::string(padding, ' '); - printElement(ss, vec, index, formatInfo.isInt, formatInfo.bitWidth); + printElement(ss, vec, index, formatInfo); } else { padding += formatInfo.maxIntDigits - digits; ss << std::fixed << std::setprecision(FLOAT_PREC) << std::string(padding, ' '); - printElement(ss, vec, index, formatInfo.isInt, formatInfo.bitWidth); + printElement(ss, vec, index, formatInfo); } } } // namespace @@ -168,9 +224,11 @@ EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond, // TODO: Implement for higher dimension vectors. 
EXPORT void triton_vector_print(int32_t pid0, int32_t pid1, int32_t pid2, const char *prefix, void *vec, int32_t isInt, - int32_t bitWidth, int64_t numElem) { + bool isSigned, int32_t bitWidth, + int64_t numElem, bool isHex) { - FormatInfo formatInfo = getFormatInfo(vec, isInt != 0, bitWidth, numElem); + FormatInfo formatInfo = + getFormatInfo(vec, isInt != 0, isSigned != 0, bitWidth, numElem, isHex); std::stringstream ss; ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix << "["; From 050a5ea0a56b201b3cbb865a0db62df4960c4149 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Wed, 4 Sep 2024 16:27:30 -0400 Subject: [PATCH 106/165] Refactor math tests + select vector lib backend via kernel option (#136) I've combined `test_libdevice` and `test_libmvec` into one `test_math` suite. This removes a great deal of code duplication. (While at it, I also sorted the test names in our github workflow file.) Moreover, I've changed how we select the vector library -- instead of using environment variables, we now can specify it via the kernel kwargs. Environment variables are generally more awkward to work with. Finally, I've made sleef the default vector library. It supports more functions and architectures, and since we now vendor it, we don't have to worry about portability issues. --- .github/workflows/build-test.yml | 11 +-- python/test/unit/cpu/test_libdevice.py | 64 ------------- python/test/unit/cpu/test_libmvec.py | 98 -------------------- python/test/unit/cpu/test_math.py | 122 +++++++++++++++++++++++++ third_party/cpu/backend/compiler.py | 33 +++++-- 5 files changed, 152 insertions(+), 176 deletions(-) delete mode 100644 python/test/unit/cpu/test_libdevice.py delete mode 100644 python/test/unit/cpu/test_libmvec.py create mode 100644 python/test/unit/cpu/test_math.py diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index f2c446076114..ba702077d317 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -82,24 +82,23 @@ jobs: run: | python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu python -m pytest -s -n 32 --device cpu \ + python/test/unit/cpu/test_math.py \ + python/test/unit/cpu/test_opt.py \ python/test/unit/language/test_annotations.py \ python/test/unit/language/test_block_pointer.py \ - python/test/unit/language/test_conversions.py \ python/test/unit/language/test_compile_errors.py \ + python/test/unit/language/test_conversions.py \ python/test/unit/language/test_decorator.py \ python/test/unit/language/test_pipeliner.py \ python/test/unit/language/test_random.py \ python/test/unit/language/test_standard.py \ + python/test/unit/runtime/test_autotuner.py \ python/test/unit/runtime/test_bindings.py \ + python/test/unit/runtime/test_cache.py \ python/test/unit/runtime/test_driver.py \ python/test/unit/runtime/test_jit.py \ python/test/unit/runtime/test_launch.py \ python/test/unit/runtime/test_subproc.py \ - python/test/unit/runtime/test_autotuner.py \ - python/test/unit/runtime/test_cache.py \ - python/test/unit/cpu/test_libdevice.py \ - python/test/unit/cpu/test_libmvec.py \ - python/test/unit/cpu/test_opt.py \ python/test/unit/test_debug_dump.py - name: Run lit tests diff --git a/python/test/unit/cpu/test_libdevice.py b/python/test/unit/cpu/test_libdevice.py deleted file mode 100644 index 07e0d5ff1ccf..000000000000 --- a/python/test/unit/cpu/test_libdevice.py +++ /dev/null @@ -1,64 +0,0 @@ -import os -import pytest -import torch - -import triton -import triton.language as tl -from 
triton.language.extra import libdevice - -torch.manual_seed(0) - - -def is_interpreter(): - return os.environ.get('TRITON_INTERPRET', '0') == '1' - - -def is_cpu(): - return not is_interpreter() and \ - triton.runtime.driver.active.get_current_target().backend == "cpu" - - -float_dtypes = ['bfloat16', 'float16', 'float32', 'float64'] - - -@pytest.mark.parametrize("dtype_str", float_dtypes) -@pytest.mark.parametrize("math_fn", [ - "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "expm1", "floor", - "isnan", "isinf", "log", "log1p", "log2", "log10", "rsqrt", "signbit", "sin", "sinh", "sqrt", "tan", "tanh" -]) -@pytest.mark.parametrize("size", [1, 4, 16, 64]) -def test_libdevice(dtype_str, math_fn, size, device): - if not is_cpu(): - pytest.skip("This test is CPU-specific") - - @triton.jit - def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): - idxs = tl.arange(0, BLOCK_SIZE) - x = tl.load(src + idxs) - y = getattr(libdevice, MATH_FN)(x) - tl.store(dst + idxs, y) - - src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) - - if math_fn == "acosh": - src = src.abs() + 1 - if math_fn == "isnan" or math_fn == "isinf": - indices = torch.randint(low=0, high=size, size=(size // 2, ), device=device) - for i in indices: - if math_fn == "isnan": - src[i] = float("nan") - else: - src[i] = float(("+" if i % 2 else "-") + "inf") - - if math_fn in ["isnan", "isinf", "signbit"]: - out_dtype = torch.bool - else: - out_dtype = getattr(torch, dtype_str) - - res = torch.empty(src.shape, dtype=out_dtype, device=device) - kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size) - if math_fn == "cbrt": - ref = src.pow(1 / 3) - else: - ref = getattr(src, math_fn)() - torch.testing.assert_close(ref, res) diff --git a/python/test/unit/cpu/test_libmvec.py b/python/test/unit/cpu/test_libmvec.py deleted file mode 100644 index 5873cc7f06a5..000000000000 --- a/python/test/unit/cpu/test_libmvec.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import pytest -import torch - -import triton -import triton.language as tl -from triton.language.extra import libdevice - - -def is_interpreter(): - return os.environ.get('TRITON_INTERPRET', '0') == '1' - - -def is_cpu(): - return not is_interpreter() and \ - triton.runtime.driver.active.get_current_target().backend == "cpu" - - -def is_x86(): - return is_cpu() and \ - triton.runtime.driver.active.get_current_target().arch == "x86_64" - - -float_dtypes = ['bfloat16', 'float16', 'float32', 'float64'] - - -@pytest.mark.parametrize("dtype_str", float_dtypes) -@pytest.mark.parametrize("math_fn", ["cos", "exp", "exp2", "log", "log2", "sin"]) -@pytest.mark.parametrize("size", [1, 2, 4, 8, 16, 32, 64, 128]) -def test_tensor_math_fn(dtype_str, math_fn, size, device): - if not is_x86(): - pytest.skip("Vectorized libm calls are supported for x86 target only.") - - @triton.jit - def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): - idxs = tl.arange(0, BLOCK_SIZE) - x = tl.load(src + idxs) - y = getattr(x, MATH_FN)() - tl.store(dst + idxs, y) - - src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) - res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) - meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size) - ref = getattr(src, math_fn)() - torch.testing.assert_close(ref, res) - - # Check generated code calls vector math function - # FP16 and BF16 are casted to FP32 for math ops - elem_size = 8 if dtype_str == "float64" else 4 - data_size 
= size * elem_size - num_vec_calls = 0 - if data_size >= 16: - num_vec_calls = 1 - if data_size > 64: - num_vec_calls = data_size / 64 - prefix = "Sleef" if os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0" else "_ZGV" - assert meta.asm["asm"].count(prefix) == num_vec_calls - - -@pytest.mark.parametrize("dtype_str", float_dtypes) -@pytest.mark.parametrize("math_fn", [ - "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "log", "log2", - "log10", "sin", "sinh", "tan", "tanh" -]) -@pytest.mark.parametrize("size", [1, 2, 4, 8, 16, 32, 64, 128]) -def test_libdevice_math_fn(dtype_str, math_fn, size, device): - if not is_x86(): - pytest.skip("Vectorized libm calls are supported for x86 target only.") - - @triton.jit - def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): - idxs = tl.arange(0, BLOCK_SIZE) - x = tl.load(src + idxs) - y = getattr(libdevice, MATH_FN)(x) - tl.store(dst + idxs, y) - - src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) - if math_fn == "acosh": - src = src.abs() + 1 - res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) - meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size) - if math_fn == "cbrt": - ref = src.pow(1 / 3) - else: - ref = getattr(src, math_fn)() - torch.testing.assert_close(ref, res) - - # Check generated code calls vector math function - # FP16 and BF16 are casted to FP32 for math ops - elem_size = 8 if dtype_str == "float64" else 4 - data_size = size * elem_size - num_vec_calls = 0 - if data_size >= 16: - num_vec_calls = 1 - if data_size > 64: - num_vec_calls = data_size / 64 - prefix = "Sleef" if os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0" else "_ZGV" - assert meta.asm["asm"].count(prefix) == num_vec_calls diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py new file mode 100644 index 000000000000..3538f042ea99 --- /dev/null +++ b/python/test/unit/cpu/test_math.py @@ -0,0 +1,122 @@ +import os +import pytest +import torch + +import triton +import triton.language as tl +from triton.language.extra import libdevice +from itertools import chain, product + + +def is_interpreter(): + return os.environ.get('TRITON_INTERPRET', '0') == '1' + + +def is_cpu(): + return not is_interpreter() and \ + triton.runtime.driver.active.get_current_target().backend == "cpu" + + +float_dtypes = ['bfloat16', 'float16', 'float32', 'float64'] +lib_prefix = { + "libsleef": "Sleef", + "libmvec": "_ZGV", +} +arch = triton.runtime.driver.active.get_current_target().arch + +vec_sizes = [1, 2, 4, 8, 16, 32, 64, 128] +scalar_sizes = [1, 4, 16, 64] + + +def check_num_vec_calls(meta, vec_lib, dtype_str, size): + # Check generated code calls vector math function + # FP16 and BF16 are casted to FP32 for math ops + elem_size = 8 if dtype_str == "float64" else 4 + data_size = size * elem_size + if data_size > 64: + num_vec_calls = data_size // 64 + elif data_size >= 16: + num_vec_calls = 1 + else: + num_vec_calls = 0 + assert meta.asm["asm"].count(lib_prefix[vec_lib]) == num_vec_calls + + +@pytest.mark.parametrize("vec_lib, size", + chain(product(["libsleef", "libmvec"], vec_sizes), product([None], scalar_sizes))) +@pytest.mark.parametrize("dtype_str", float_dtypes) +@pytest.mark.parametrize("math_fn", ["cos", "exp", "exp2", "log", "log2", "sin"]) +def test_tensor_math_fn(vec_lib, dtype_str, math_fn, size, device): + if not is_cpu(): + pytest.skip("This test is CPU-specific") + if vec_lib == "libmvec" and arch != "x86_64": + 
pytest.skip("Vectorized libm calls are supported for x86 target only.") + + @triton.jit + def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = getattr(x, MATH_FN)() + tl.store(dst + idxs, y) + + src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + res = torch.empty(src.shape, dtype=getattr(torch, dtype_str), device=device) + meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size, vec_lib=vec_lib) + ref = getattr(src, math_fn)() + torch.testing.assert_close(ref, res) + + if vec_lib is not None: + check_num_vec_calls(meta, vec_lib, dtype_str, size) + + +@pytest.mark.parametrize("vec_lib, size", + chain(product(["libsleef", "libmvec"], vec_sizes), product([None], scalar_sizes))) +@pytest.mark.parametrize("dtype_str", float_dtypes) +@pytest.mark.parametrize("math_fn", [ + "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "expm1", "floor", + "isnan", "isinf", "log", "log1p", "log2", "log10", "rsqrt", "signbit", "sin", "sinh", "sqrt", "tan", "tanh" +]) +def test_libdevice_math_fn(vec_lib, dtype_str, math_fn, size, device): + if not is_cpu(): + pytest.skip("This test is CPU-specific") + if vec_lib == "libmvec" and arch != "x86_64": + pytest.skip("Vectorized libm calls are supported for x86 target only.") + + @triton.jit + def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(src + idxs) + y = getattr(libdevice, MATH_FN)(x) + tl.store(dst + idxs, y) + + src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + # Customize inputs + if math_fn == "acosh": + src = src.abs() + 1 + if math_fn == "isnan" or math_fn == "isinf": + indices = torch.randint(low=0, high=size, size=(size // 2, ), device=device) + for i in indices: + if math_fn == "isnan": + src[i] = float("nan") + else: + src[i] = float(("+" if i % 2 else "-") + "inf") + + # Generate reference output + if math_fn == "cbrt": + ref = src.pow(1 / 3) + else: + ref = getattr(src, math_fn)() + + res = torch.empty(src.shape, dtype=ref.dtype, device=device) + meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size, vec_lib=vec_lib) + torch.testing.assert_close(ref, res) + + if vec_lib is None: + return + + # These are not implemented via extern library calls + native_impls = ["expm1", "floor", "isnan", "isinf", "rsqrt", "signbit", "sqrt"] + if math_fn not in native_impls: + check_num_vec_calls(meta, vec_lib, dtype_str, size) + else: + assert meta.asm["asm"].count(lib_prefix[vec_lib]) == 0 diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6d8cd42912d4..fdce803e3a97 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from types import ModuleType -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple from triton._C.libtriton import cpu, ir, llvm, passes from triton.backends.compiler import BaseBackend, GPUTarget @@ -19,6 +19,9 @@ def min_dot_size(target: GPUTarget): return lambda lhsType, rhsType: (4, 4, 4) +VecLib = cpu.passes.ttcpuir.VecLib + + @dataclass(frozen=True) class CPUOptions: # GPU-specific options are used in several places. 
@@ -36,6 +39,7 @@ class CPUOptions: enable_fp_fusion: bool = True max_num_imprecise_acc_default: int = 0 enable_fast_math: bool = True + vec_lib: Optional[str] = 'libsleef' # TODO: We may introduce CPU-specific options like # of cores. @@ -47,6 +51,18 @@ def hash(self): key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())]) return hashlib.sha256(key.encode("utf-8")).hexdigest() + def get_vec_lib(self) -> VecLib: + if self.vec_lib is None: + return None + # Parse enum from str here (instead of in parse_options()) because the options have to be JSON-serializable, + # and pybind enums are not serializable. + vec_lib = VecLib.__members__.get(self.vec_lib, None) + if vec_lib is None: + raise ValueError( + f"Unexpected value for vec_lib: {self.vec_lib}, should be one of {{{', '.join(VecLib.__members__.keys())}}}" + ) + return vec_lib + class CPUBackend(BaseBackend): @@ -63,7 +79,7 @@ def __init__(self, target: tuple) -> None: def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} - if not "enable_fast_math" in args: + if "enable_fast_math" not in args: args["enable_fast_math"] = os.getenv("TRITON_CPU_FAST_MATH", "1") != "0" return CPUOptions(**args) @@ -165,13 +181,14 @@ def make_llir(self, src, metadata, options): cpu.passes.ttcpuir.add_memory_op_to_llvmir(pm) cpu.passes.ttcpuir.add_atomic_ops_to_llvmir(pm) cpu.passes.ttcpuir.add_debug_ops_to_llvmir(pm) - vec_lib = None - if os.environ.get("TRITON_CPU_USE_LIBMVEC", "1") != "0": - vec_lib = cpu.passes.ttcpuir.VecLib.libmvec - if os.environ.get("TRITON_CPU_USE_SLEEF", "0") != "0": - vec_lib = cpu.passes.ttcpuir.VecLib.libsleef - if vec_lib is not None and self.cpu_arch == "x86_64" and "avx512f" in self.cpu_features: + + vec_lib_requirements = { + VecLib.libsleef: {"neon", "sse", "avx"}, + VecLib.libmvec: {"avx512f"}, + } + if (vec_lib := options.get_vec_lib()) and vec_lib_requirements[vec_lib] & self.cpu_features: cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib) + passes.convert.add_math_to_llvmir(pm) cpu.passes.ttcpuir.add_math_to_libm(pm) cpu.passes.ttcpuir.add_vector_to_llvmir(pm, options.enable_fast_math) From bc9952926938503fac5ec2a9351e42fe49a70971 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Wed, 4 Sep 2024 18:36:58 -0400 Subject: [PATCH 107/165] Vectorize expm1, sqrt, and floor using sleef (#137) --- python/test/unit/cpu/test_math.py | 7 +++++-- .../cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 17 ++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py index 3538f042ea99..9a897d378e6a 100644 --- a/python/test/unit/cpu/test_math.py +++ b/python/test/unit/cpu/test_math.py @@ -115,8 +115,11 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): return # These are not implemented via extern library calls - native_impls = ["expm1", "floor", "isnan", "isinf", "rsqrt", "signbit", "sqrt"] - if math_fn not in native_impls: + native_impls = { + "libmvec": {"expm1", "floor", "isnan", "isinf", "rsqrt", "signbit", "sqrt"}, + "libsleef": {"isnan", "isinf", "rsqrt", "signbit"}, + } + if math_fn not in native_impls[vec_lib]: check_num_vec_calls(meta, vec_lib, dtype_str, size) else: assert meta.asm["asm"].count(lib_prefix[vec_lib]) == 0 diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index 782026146979..1994639bf6eb 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp +++ 
b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -169,7 +169,12 @@ class MvecNameGenerator { class SleefNameGenerator { public: SleefNameGenerator(StringRef baseName, unsigned ulp = 10) - : baseName(baseName), ulp(std::to_string(ulp)) {} + : baseName(baseName), ulpSuffix(4, '\0') { + if (ulp == 0) + ulpSuffix = ""; + else + sprintf(ulpSuffix.data(), "_u%02u", ulp); + } std::string operator()(unsigned bitwidth, unsigned numel, ValueRange /*operands*/) const { @@ -179,12 +184,12 @@ class SleefNameGenerator { if (vecSize < 128) return ""; return "Sleef_" + baseName + (bitwidth == 32 ? "f" : "d") + - std::to_string(numel) + "_u" + ulp; + std::to_string(numel) + ulpSuffix; } private: std::string baseName; - std::string ulp; + std::string ulpSuffix; }; template struct VecOpToVecLib : public OpRewritePattern { @@ -256,6 +261,12 @@ struct MathToVecLibPass } case VecLib::Sleef: { populateCommonPatterns(patterns); + populatePatternsForOp(patterns, + SleefNameGenerator("expm1")); + populatePatternsForOp( + patterns, SleefNameGenerator("floor", /*ulp=*/0)); + populatePatternsForOp( + patterns, SleefNameGenerator("sqrt", /*ulp=*/5)); break; } } From ae0810a0b7266bed34881f7381e84201d29988ba Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 4 Sep 2024 18:19:46 -0500 Subject: [PATCH 108/165] Fix infinite optimization loop for mask optimization. (#138) Signed-off-by: Ilya Enkovich --- test/TritonCPU/optimize-masks.mlir | 20 +++++++++++++++++++ .../lib/TritonCPUTransforms/OptimizeMasks.cpp | 5 ++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/test/TritonCPU/optimize-masks.mlir b/test/TritonCPU/optimize-masks.mlir index 5ab482a565a6..470e0f6b3419 100644 --- a/test/TritonCPU/optimize-masks.mlir +++ b/test/TritonCPU/optimize-masks.mlir @@ -71,3 +71,23 @@ module { tt.return } } + +// ----- + +// Regression test for the infinite optimization loop bug. 
+ +module { + tt.func public @remove_masks_in_for_loop(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> + %c15_i32 = arith.constant 15 : i32 + %c16_i32 = arith.constant 16 : i32 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst_0 = arith.constant dense<0.000000e+00> : vector<16xf32> + %0 = arith.addi %arg1, %c15_i32 : i32 + %1 = arith.divsi %0, %c16_i32 : i32 + tt.store %arg0, %1 : !tt.ptr + tt.return + } +} diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp index 332ed5c97c7b..e747ef16c957 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -116,8 +116,11 @@ struct CdivToDiv : public OpRewritePattern { } return false; }); - if (!replaced) + + if (!replaced) { + rewriter.eraseOp(newRes.getDefiningOp()); return failure(); + } return success(); } From a3ef9d4b877e05afec161bef682a38eaac156ecd Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Thu, 5 Sep 2024 13:54:17 -0400 Subject: [PATCH 109/165] Implement libdevice.trunc (#140) --- python/src/ir.cc | 4 ++++ python/test/unit/cpu/test_math.py | 4 ++-- python/triton/language/extra/cpu/libdevice.py | 5 +++++ third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 2 ++ .../cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp | 1 + .../cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp | 1 + 6 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/src/ir.cc b/python/src/ir.cc index da11b0f39087..a193b4d6a583 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1709,6 +1709,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &val) -> Value { return self.create(val); }) + .def("create_trunc", + [](TritonOpBuilder &self, Value &val) -> Value { + return self.create(val); + }) .def("create_reduce", [](TritonOpBuilder &self, std::vector operands, int axis) -> OpState { return self.create(operands, axis); }) diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py index 9a897d378e6a..793bd0e6f7e7 100644 --- a/python/test/unit/cpu/test_math.py +++ b/python/test/unit/cpu/test_math.py @@ -74,7 +74,7 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): @pytest.mark.parametrize("dtype_str", float_dtypes) @pytest.mark.parametrize("math_fn", [ "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "expm1", "floor", - "isnan", "isinf", "log", "log1p", "log2", "log10", "rsqrt", "signbit", "sin", "sinh", "sqrt", "tan", "tanh" + "isnan", "isinf", "log", "log1p", "log2", "log10", "rsqrt", "signbit", "sin", "sinh", "sqrt", "tan", "tanh", "trunc" ]) def test_libdevice_math_fn(vec_lib, dtype_str, math_fn, size, device): if not is_cpu(): @@ -116,7 +116,7 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): # These are not implemented via extern library calls native_impls = { - "libmvec": {"expm1", "floor", "isnan", "isinf", "rsqrt", "signbit", "sqrt"}, + "libmvec": {"expm1", "floor", "isnan", "isinf", "rsqrt", "signbit", "sqrt", "trunc"}, "libsleef": {"isnan", "isinf", "rsqrt", "signbit"}, } if math_fn not in native_impls[vec_lib]: diff --git a/python/triton/language/extra/cpu/libdevice.py b/python/triton/language/extra/cpu/libdevice.py index 
e442d0234d0f..d1b410fdd19b 100644 --- a/python/triton/language/extra/cpu/libdevice.py +++ b/python/triton/language/extra/cpu/libdevice.py @@ -124,6 +124,11 @@ def tanh(arg0, _builder=None): return core.tensor(_builder.create_tanh(arg0.handle), arg0.type) +@core.extern +def trunc(arg0, _builder=None): + return core.tensor(_builder.create_trunc(arg0.handle), arg0.type) + + @jit def _const(v, dtype): """ diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index 1994639bf6eb..68aa6c0bee0c 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -267,6 +267,8 @@ struct MathToVecLibPass patterns, SleefNameGenerator("floor", /*ulp=*/0)); populatePatternsForOp( patterns, SleefNameGenerator("sqrt", /*ulp=*/5)); + populatePatternsForOp( + patterns, SleefNameGenerator("trunc", /*ulp=*/0)); break; } } diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index d3d5ca95d70d..06ad1f1f6802 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -433,6 +433,7 @@ struct ConvertUnsupportedOps patterns.add>(context); patterns.add>(context); patterns.add>(context); + patterns.add>(context); } if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 7b3836898c2b..4c37524e1b5f 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -241,6 +241,7 @@ struct ConvertElementwiseOps patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); patterns.add>( typeConverter, context); From 9a68cf026c75767c17beed0a69d4942dadc46863 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 6 Sep 2024 10:28:33 -0500 Subject: [PATCH 110/165] Remove old LLVM bug workaround. (#141) Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index be8d914f9122..32b35c30df2b 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2420,12 +2420,6 @@ def kernel(X, Z, BLOCK: tl.constexpr): def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested - # fpext fp16->fp32 is broken in LLVM for large vectors: - # https://github.com/llvm/llvm-project/issues/95278 - # TODO: remove the change after the bug is fixed. - if is_cpu() and dtype_str == "float16": - shape = (min(shape[0], 512), min(shape[1], 512)) - @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr, USE_I1: tl.constexpr): From 660948204ee97a60638df2dc2596b52c6509239c Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 9 Sep 2024 12:08:51 -0500 Subject: [PATCH 111/165] Add kernel execution time measurement using hooks for do_bench (#139) * Add timing measurements using launch hooks for CPU. 
Signed-off-by: Ilya Enkovich * Avoid OMP for trivial grid in CPU launcher. Signed-off-by: Ilya Enkovich * Add more measurement options for vector-add tutorial. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- python/triton/testing.py | 55 ++++++++++++++++++--- python/tutorials/01-vector-add.py | 80 +++++++++++++++++++++++++++---- third_party/cpu/backend/driver.py | 6 ++- 3 files changed, 123 insertions(+), 18 deletions(-) diff --git a/python/triton/testing.py b/python/triton/testing.py index 22adade12a35..ed47eca834b5 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -9,27 +9,63 @@ from typing import Any, Dict, List from . import language as tl from . import runtime +import triton class CPUDeviceInterface: - class Event: + class HooksTimeAccessor: - def __init__(self, enable_timing=True): - self.time = 0 + def __init__(self, di): + self.di = di + self.record_idx = 0 def elapsed_time(self, end_event) -> float: - return (end_event.time - self.time) * 1000 + total_time = 0 + for i in range(self.record_idx, end_event.record_idx): + total_time += self.di.kernel_times[i] + return total_time * 1000 def record(self): - self.time = time.perf_counter() + self.record_idx = len(self.di.kernel_times) + + class TimerEvent: + + def __init__(self): + self.timer = 0 + + def elapsed_time(self, end_event) -> float: + return (end_event.timer - self.timer) * 1000 + + def record(self): + self.timer = time.perf_counter() def __init__(self): - pass + self.kernel_times = [] + self.last_start = 0 + self.use_hooks = False + triton.compiler.CompiledKernel.launch_enter_hook = None + triton.compiler.CompiledKernel.launch_exit_hook = None + + def enable_hook_timing(self): + self.use_hooks = True + triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook() + triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook() def synchronize(self): pass + def _enter_hook(self): + self.last_start = time.perf_counter() + + def _exit_hook(self): + self.kernel_times.append(time.perf_counter() - self.last_start) + + def Event(self, enable_timing=True): + if self.use_hooks: + return CPUDeviceInterface.HooksTimeAccessor(self) + return CPUDeviceInterface.TimerEvent() + def nvsmi(attrs): attrs = ','.join(attrs) @@ -141,7 +177,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod return _summarize_statistics(ret, quantiles, return_mode) -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", measure_time_with_hooks=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. @@ -178,6 +214,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m di.synchronize() estimate_ms = start_event.elapsed_time(end_event) / 5 + # For CPU we can use entry and exit hooks to measure execution time + # more precisely. 
+ if measure_time_with_hooks: + di.enable_hook_timing() + # compute number of warmup and repeat n_warmup = max(1, int(warmup / estimate_ms)) n_repeat = max(1, int(rep / estimate_ms)) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 222ad5359c37..672a32562422 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -26,6 +26,7 @@ DEVICE = triton.runtime.driver.active.get_active_torch_device() GPU_BLOCK_SIZE = 1024 CPU_BLOCK_SIZE = 4096 +CPU_ST_THRESHOLD = 65536 USE_GPU = False @triton.jit @@ -56,6 +57,26 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) +@triton.jit +def add_kernel_tiled(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + TILE_SIZE: tl.constexpr, # Number of elements each iteration should process. + # NOTE `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + for i in range(0, tl.cdiv(BLOCK_SIZE, TILE_SIZE)): + offsets = block_start + i * TILE_SIZE + tl.arange(0, TILE_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + # %% # Let's also declare a helper function to (1) allocate the `z` tensor # and (2) enqueue the above kernel with appropriate grid/block sizes: @@ -81,6 +102,28 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device): return output +def add_tiled(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=CPU_BLOCK_SIZE, TILE_SIZE=16) + return output + + +def add_tiled_with_st_threshold(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + # TODO: try to choose the best block size using autotuner + BLOCK_SIZE = triton.next_power_of_2(n_elements) + if BLOCK_SIZE > CPU_ST_THRESHOLD: + BLOCK_SIZE = CPU_BLOCK_SIZE + add_kernel_tiled[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE, TILE_SIZE=16) + return output + + # %% # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: torch.manual_seed(0) @@ -94,10 +137,19 @@ def add(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, device): print(output_triton_cpu) print(f'The maximum difference between torch-cpu and triton-cpu is ' f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') +output_triton_cpu = add_tiled(x, y, None) +print(f'The maximum difference between torch-cpu-tiled and triton-cpu is ' + f'{torch.max(torch.abs(output_torch_cpu - output_triton_cpu))}') -LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu'] -LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU'] -LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '-')] +LINE_VALS = [ + 'triton-cpu', 'triton-cpu-hooks', 'triton-cpu-tiled', 'triton-cpu-tiled-hooks', 'triton-cpu-tiled-tuned-hooks', + 'torch-cpu' +] +LINE_NAMES = [ + 'TritonCPU', 'TritonCPU (hooks)', 'TritonCPUTiled', 
'TritonCPUTiled (hooks)', 'TritonCPUTiled (tuned, hooks)', + 'TorchCPU' +] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('green', '-')] if USE_GPU and triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() @@ -149,6 +201,7 @@ def benchmark(size, provider): triton.runtime.driver.set_active_to_cpu() else: triton.runtime.driver.set_active_to_gpu() + output = torch.empty_like(x) quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': @@ -159,17 +212,24 @@ def benchmark(size, provider): elif provider == 'torch-cpu': # Note that we preallocate the output buffer here to only measure the kernel performance # without a large chunk of memory allocation. - output = torch.empty_like(x) ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles, device_type=DEVICE) - elif provider == 'triton-cpu-single': - output = torch.empty_like(x) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, - device_type=DEVICE) elif provider == 'triton-cpu': - output = torch.empty_like(x) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, True), quantiles=quantiles, + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, DEVICE), quantiles=quantiles, + device_type=DEVICE) + elif provider == 'triton-cpu-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, DEVICE), quantiles=quantiles, + device_type=DEVICE, measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles, device_type=DEVICE) + elif provider == 'triton-cpu-tiled-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles, + device_type=DEVICE, measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled-tuned-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_with_st_threshold(x, y, output), + quantiles=quantiles, device_type=DEVICE, + measure_time_with_hooks=True) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 44d980e01987..9bc9db4379f8 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -229,9 +229,13 @@ def format_of(ty): static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ // TODO: Consider using omp collapse(3) clause for simplicity? 
- auto all_grids = get_all_grids(gridX, gridY, gridZ); size_t N = gridX * gridY * gridZ; + if (N == 1) {{ + (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} 0, 0, 0, 1, 1, 1); + return; + }} + auto all_grids = get_all_grids(gridX, gridY, gridZ); if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{ if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) printf("Single core launcher\\n"); From 9606a34e1d229ca1c0120b6c3ecce9ea7728f93e Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 9 Sep 2024 11:04:05 -0700 Subject: [PATCH 112/165] Use llvm_unreachable in cpu_runtime.cpp (#145) --- python/tutorials/01-vector-add.py | 1 + third_party/cpu/runtime/cpu_runtime.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 672a32562422..90f65b12e351 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -26,6 +26,7 @@ DEVICE = triton.runtime.driver.active.get_active_torch_device() GPU_BLOCK_SIZE = 1024 CPU_BLOCK_SIZE = 4096 +# Single Thread Threshold CPU_ST_THRESHOLD = 65536 USE_GPU = False diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 9f306ececb9d..dba79828dc5b 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -1,3 +1,4 @@ +#include "llvm/Support/ErrorHandling.h" #include #include #include @@ -43,14 +44,13 @@ computeDigitInfoHelper(const void *array, size_t index) { std::pair computeDigitInfo(void *vec, bool isInt, bool isSigned, int32_t bitWidth, size_t index) { - if (isInt == 0) { if (bitWidth == 32) return computeDigitInfoHelper(vec, index); else if (bitWidth == 64) return computeDigitInfoHelper(vec, index); else - assert(false && "Unsupported bitWidth"); + llvm_unreachable("Unsupported bitWidth"); } else { if (isSigned) { if (bitWidth == 64) @@ -76,7 +76,7 @@ std::pair computeDigitInfo(void *vec, bool isInt, bool isSigned, return computeDigitInfoHelper(vec, index); } printf("bitWidth: %d\n", bitWidth); - assert(false && "Unsupported bitWidth"); + llvm_unreachable("Unsupported bitWidth"); } } @@ -126,7 +126,7 @@ void printElement(std::stringstream &ss, const void *vec, size_t index, printElementHelper(ss, vec, index); break; default: - assert(false && "Unsupported bitWidth"); + llvm_unreachable("Unsupported bitWidth"); } } else { if (formatInfo.isSigned) { @@ -148,7 +148,7 @@ void printElement(std::stringstream &ss, const void *vec, size_t index, printElementHelper(ss, vec, index); break; default: - assert(false && "Unsupported bitWidth"); + llvm_unreachable("Unsupported bitWidth"); } } else { switch (formatInfo.bitWidth) { @@ -168,7 +168,7 @@ void printElement(std::stringstream &ss, const void *vec, size_t index, printElementHelper(ss, vec, index); break; default: - assert(false && "Unsupported bitWidth"); + llvm_unreachable("Unsupported bitWidth"); } } } From 77b2d2ebbd3ab1bec24034424020a424e1c410c9 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 9 Sep 2024 16:48:06 -0700 Subject: [PATCH 113/165] Fix undefined symbole error in libTritonCPURuntime.so (#146) --- third_party/cpu/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index fd55642022e4..2acf7a6b6f48 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -8,6 +8,7 @@ if(TRITON_BUILD_PYTHON_MODULE) endif() add_library(TritonCPURuntime SHARED 
${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp)
+target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport)

 # Build and link sleef
 set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE)

From f1e3f683dab5ffb9b01523184a03f5cf91b4f386 Mon Sep 17 00:00:00 2001
From: Dmitrii Makarenko
Date: Wed, 18 Sep 2024 16:01:05 +0200
Subject: [PATCH 114/165] [Dot3D test] Enable with lower block size (#117)

This commit enables the dot3d test and skips the float16->float16 cases.
In these cases compilation takes too long and the accuracy is lower than
the test expects.

---
 python/test/unit/language/test_core.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 32b35c30df2b..e75aac9f0fc3 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -4036,7 +4036,7 @@ def make_finite(x, dtype):
     [(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_str)
      for B in [1, 2, 4, 8]
      for num_warps in [1, 2, 4, 8, 16]
-     for BLOCK_M, BLOCK_N in [(32, 32)]
+     for BLOCK_M, BLOCK_N in [(32, 32) if not is_cpu() else (4, 4)]
      for M, N, K in [(64, 64, 64), (32, 32, 32)]
      for in_dtype_str, out_dtype_str in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'),
                                          ('float32', 'float32')]] +
@@ -4060,7 +4060,9 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
         if out_dtype_str == "float16":
             pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
     elif is_cpu():
-        pytest.skip("Test is skipped due to too long execution time on CPU")
+        if out_dtype_str == "float16":
+            pytest.skip("Test is skipped due to float16 accuracy issue")
+        input_precision = "tf32" if in_dtype_str == 'float32' else "ieee"
     else:
         input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee"
     if not is_interpreter() and (BLOCK_M < 16 or BLOCK_N < 16):

From ab1f1aaa3398e5b665480e94aa26fed2c7c5a1f6 Mon Sep 17 00:00:00 2001
From: Ilya Enkovich
Date: Thu, 19 Sep 2024 20:37:04 -0500
Subject: [PATCH 115/165] Add an option to choose between default reduction lowering and our own.
(#98) Signed-off-by: Ilya Enkovich --- third_party/cpu/backend/compiler.py | 2 +- .../cpu/include/TritonToTritonCPU/Passes.h | 2 + .../cpu/include/TritonToTritonCPU/Passes.td | 6 +++ .../TritonToTritonCPU/ConvertReductionOp.cpp | 37 ++++++++++++++----- third_party/cpu/triton_cpu.cc | 8 ++-- 5 files changed, 41 insertions(+), 14 deletions(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index fdce803e3a97..e96bffa8ccc4 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -123,7 +123,7 @@ def make_ttcir(mod, metadata, opt): cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm) cpu.passes.ttcpuir.add_convert_dot_op(pm) cpu.passes.ttcpuir.add_convert_histogram_op(pm) - cpu.passes.ttcpuir.add_convert_reduction_op(pm) + cpu.passes.ttcpuir.add_convert_reduction_op(pm, False) cpu.passes.ttcpuir.add_convert_scan_op(pm) cpu.passes.ttcpuir.add_convert_cf_ops(pm) cpu.passes.ttcpuir.add_convert_atomic_ops(pm) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index ac2c03b6abf8..8ab1d2f5631e 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -28,6 +28,8 @@ std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); +std::unique_ptr> +createConvertReductionOp(bool useMultiDimReductionOp); std::unique_ptr> createConvertScanOp(); std::unique_ptr> createConvertAtomicOps(); std::unique_ptr> createConvertDebugOps(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 161fec9babcd..3e5e2cd89ad3 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -113,6 +113,12 @@ def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> }]; let constructor = "mlir::triton::cpu::createConvertReductionOp()"; + let options = [ + Option<"useMultiDimReductionOp", "use-multidim-reduction-op", + "bool", /*default*/"false", + "Use vector::MultiDimReductionOp and its default lowering when possible.">, + ]; + let dependentDialects = ["mlir::arith::ArithDialect", "mlir::vector::VectorDialect", "mlir::scf::SCFDialect", diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index e660edaf97a5..3ac14ad21ff9 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -14,8 +14,10 @@ namespace mlir { namespace triton { +namespace cpu { #define GEN_PASS_DEF_CONVERTREDUCTIONOP #include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace cpu } // namespace triton } // namespace mlir @@ -42,16 +44,20 @@ class ReductionConversionTarget : public ConversionTarget { struct ReduceOpConversion : public ReduceScanOpConversionBase { - using ReduceScanOpConversionBase::ReduceScanOpConversionBase; + ReduceOpConversion(bool useMultiDimReductionOp, + const TypeConverter &typeConverter, MLIRContext *context) + : ReduceScanOpConversionBase(typeConverter, context) { + this->useMultiDimReductionOp = useMultiDimReductionOp; + } LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // More simple cases with a single input and a 
single combine - // operation can utilize target-specific reduction operations like - // horizaontal vector operations. We detect such cases here and map - // them to the vector::MultiDimReductionOp. - if (succeeded(mapToMultiDimReductionOp(op, rewriter))) + // More simple cases with a single input and a single combine operation + // can be mapped to a vector::MultiDimReductionOp. The resulting code + // depends on a quality of LLVM backend and is not always perfect though. + if (useMultiDimReductionOp && + succeeded(mapToMultiDimReductionOp(op, rewriter))) return success(); return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter); @@ -249,13 +255,18 @@ struct ReduceOpConversion return rewriter.create(loc, resTy, initVal); } + +private: + bool useMultiDimReductionOp; }; struct ConvertReductionOp - : public triton::impl::ConvertReductionOpBase { - using ConvertReductionOpBase::ConvertReductionOpBase; + : public triton::cpu::impl::ConvertReductionOpBase { + ConvertReductionOp() = default; - ConvertReductionOp() : ConvertReductionOpBase() {} + ConvertReductionOp(bool useMultiDimReductionOp) { + this->useMultiDimReductionOp = useMultiDimReductionOp; + } void runOnOperation() override { MLIRContext *context = &getContext(); @@ -264,7 +275,8 @@ struct ConvertReductionOp TritonToTritonCPUTypeConverter typeConverter; ReductionConversionTarget convTarget(*context, typeConverter); RewritePatternSet patterns(context); - patterns.add(typeConverter, context); + patterns.add(useMultiDimReductionOp, typeConverter, + context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); @@ -281,6 +293,11 @@ std::unique_ptr> createConvertReductionOp() { return std::make_unique(); } +std::unique_ptr> +createConvertReductionOp(bool useMultiDimReductionOp) { + return std::make_unique(useMultiDimReductionOp); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index aa365779f407..0ab6dba9b7bc 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -46,9 +46,11 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_convert_histogram_op", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); }); - m.def("add_convert_reduction_op", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertReductionOp()); - }); + m.def("add_convert_reduction_op", + [](mlir::PassManager &pm, bool use_multidim_reduction_op) { + pm.addPass(mlir::triton::cpu::createConvertReductionOp( + use_multidim_reduction_op)); + }); m.def("add_convert_scan_op", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertScanOp()); }); From d3f1b311007a2ad46bf28520188bbbc092039edd Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Sat, 21 Sep 2024 01:32:32 -0700 Subject: [PATCH 116/165] Fix regressions due to rebasing to the latest upstream --- third_party/cpu/backend/compiler.py | 5 +++++ third_party/cpu/include/TritonToTritonCPU/Passes.h | 8 ++++++++ .../lib/TritonCPUTransforms/DecomposeFpConversions.cpp | 5 +---- .../cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp | 3 ++- .../cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp | 9 ++++++--- .../cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp | 3 ++- third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp | 3 ++- .../cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp | 3 ++- .../cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp | 3 ++- 
.../cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp | 3 ++- .../cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp | 3 ++- third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp | 3 ++- .../cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp | 3 ++- third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp | 3 ++- 14 files changed, 40 insertions(+), 17 deletions(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index e96bffa8ccc4..3cb957f33207 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -33,6 +33,8 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15", "fp8e4nv") + deprecated_fp8_dtypes: Tuple[str] = () allowed_dot_input_precisions: Tuple[str] = ("ieee", "tf32", "tf32x3") allow_fp8e4nv: bool = True allow_fp8e4b15: bool = True @@ -81,6 +83,9 @@ def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} if "enable_fast_math" not in args: args["enable_fast_math"] = os.getenv("TRITON_CPU_FAST_MATH", "1") != "0" + if "supported_fp8_dtypes" not in args: + supported_fp8_dtypes = set(CPUOptions.supported_fp8_dtypes) + args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) return CPUOptions(**args) def pack_metadata(self, metadata): diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 8ab1d2f5631e..147aa28da18e 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -37,6 +37,14 @@ std::unique_ptr> createConvertDebugOps(); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonToTritonCPU/Passes.h.inc" +inline LogicalResult applyPartialConversionNoBuildMaterializations( + Operation *op, const ConversionTarget &target, + const FrozenRewritePatternSet &patterns, + ConversionConfig config = ConversionConfig()) { + config.buildMaterializations = false; + return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config); +} + } // namespace cpu } // namespace triton diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp index be5585347ba2..66410855960a 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -376,10 +376,7 @@ Value convertFp32ToFp8E5M2B16Rtne(Location loc, Value src, FpToFpConvFn getFpToFpConversionFn(Type srcTy, Type dstTy, std::optional roundMode) { - // TODO: Float8E4M3FNUZType is used for both float8e4nv and float8e4b8 by - // frontend. float8e4b8 tests are skipped for CPU so we interpret this type as - // float8e4nv. Needs to be fixed. See get_fp8e4nv_ty at ir.cc. 
- auto F8E4M3TyID = TypeID::get(); + auto F8E4M3TyID = TypeID::get(); auto F8E5M2TyID = TypeID::get(); auto F8E5M2B16TyID = TypeID::get(); auto F16TyID = TypeID::get(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp index bab0cd94c57e..281dd8474b30 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp @@ -192,7 +192,8 @@ struct ConvertAtomicOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp index 491b647103a7..fddbf3a50583 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp @@ -168,7 +168,8 @@ struct ConvertControlFlowOps RewritePatternSet patterns(context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } @@ -179,7 +180,8 @@ struct ConvertControlFlowOps { RewritePatternSet patterns(context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } @@ -195,7 +197,8 @@ struct ConvertControlFlowOps RewritePatternSet patterns(context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp index 8a83156e4c52..c60ef5a2ed2c 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -109,7 +109,8 @@ struct ConvertDebugOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp index 06cfb0d834d0..102ad1f81448 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -110,7 +110,8 @@ struct ConvertDotOp : public triton::impl::ConvertDotOpBase { RewritePatternSet patterns(context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git 
a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp index a39a93e42446..31d80f115e75 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -229,7 +229,8 @@ struct ConvertElemManipOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 4c37524e1b5f..b0142d69ea16 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -253,7 +253,8 @@ struct ConvertElementwiseOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp index 0bcbfcc9f264..f5e18763f168 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp @@ -114,7 +114,8 @@ struct ConvertHistogramOp RewritePatternSet patterns(context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 02a458986269..814552f1691b 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -914,7 +914,8 @@ struct ConvertMemoryOps patterns.add(axisInfoAnalysis, shapeInfoAnalysis, pointerConverter, useScalarLoops, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp index 27f49a3078c1..3994723e4a26 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -177,7 +177,8 @@ struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversionNoBuildMaterializations( + mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index 3ac14ad21ff9..68d65ffb3ae5 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ 
b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp
@@ -278,7 +278,8 @@ struct ConvertReductionOp
     patterns.add(useMultiDimReductionOp, typeConverter,
                                           context);

-    if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
+    if (failed(applyPartialConversionNoBuildMaterializations(
+            mod, convTarget, std::move(patterns))))
       return signalPassFailure();
   }
 };
diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp
index fef15b046621..0697b4540252 100644
--- a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp
+++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp
@@ -136,7 +136,8 @@ struct ConvertScanOp : public triton::impl::ConvertScanOpBase {
     RewritePatternSet patterns(context);
     patterns.add(typeConverter, context);

-    if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
+    if (failed(applyPartialConversionNoBuildMaterializations(
+            mod, convTarget, std::move(patterns))))
       return signalPassFailure();
   }
 };

From 819ea4373f6e6804f9cb909e1d4bd7a37daeb1cb Mon Sep 17 00:00:00 2001
From: Minjang Kim
Date: Sun, 22 Sep 2024 01:50:32 -0700
Subject: [PATCH 117/165] Update build-test.yml for pybind11

The recent upstream changes need pybind11.
---
 .github/workflows/build-test.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml
index ba702077d317..a5178e8f34c8 100644
--- a/.github/workflows/build-test.yml
+++ b/.github/workflows/build-test.yml
@@ -67,7 +67,7 @@ jobs:
       - name: Install pip and apt dependencies
         run: |
           python3 -m pip install --upgrade pip
-          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit
+          python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit pybind11
           sudo apt-get update
           sudo apt-get install -y zlib1g-dev g++
           pip install torch==2.1.2

From 73e909d6a1d6f36c62c7fed5ac8219ba040911e8 Mon Sep 17 00:00:00 2001
From: Dmitrii Makarenko
Date: Mon, 23 Sep 2024 20:28:01 +0200
Subject: [PATCH 118/165] [FP8 support] Enable Float8 tests failed after rebase (#151)

* [FP8 support] Enable Float8 tests failed after rebase

This commit adds the list of dtypes supported on CPU for test_fp8_support
from the test_compile_errors.py suite. Also, upstream updated
get_fp8e4nv_ty in ir.cc, which now returns "Float8E4M3FNType" instead of
"Float8E4M3FNUZType", so the corresponding conversions are updated to use
that type.

* Fix exception check in test_typeconvert_upcast.
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich Co-authored-by: Ilya Enkovich --- .../test/unit/language/test_compile_errors.py | 10 +++++++++- python/test/unit/language/test_conversions.py | 7 +++---- python/test/unit/language/test_core.py | 2 ++ third_party/cpu/backend/compiler.py | 2 +- .../DecomposeFpConversions.cpp | 20 +++++++------------ 5 files changed, 22 insertions(+), 19 deletions(-) diff --git a/python/test/unit/language/test_compile_errors.py b/python/test/unit/language/test_compile_errors.py index 3e168e2bb55e..8fcccb94df97 100644 --- a/python/test/unit/language/test_compile_errors.py +++ b/python/test/unit/language/test_compile_errors.py @@ -7,7 +7,7 @@ import triton.language as tl from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure import traceback -from triton._internal_testing import is_cuda, is_hip, is_hip_mi300 +from triton._internal_testing import is_interpreter, is_cuda, is_hip, is_hip_mi300 def format_exception(type, value, tb): @@ -15,6 +15,10 @@ def format_exception(type, value, tb): return "\n".join(list_msg) +def is_cpu(): + return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cpu" + + def test_err_undefined_variable(): @triton.jit @@ -375,6 +379,10 @@ def test_fp8_support(fresh_triton_cache, dtype): elif is_hip(): if is_hip_mi300(): supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16] + elif is_interpreter(): + supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15] + elif is_cpu(): + supported_dtypes = [tl.float8e5, tl.float8e5b16, tl.float8e4nv] @triton.jit def dtype_kernel(dtype: tl.constexpr): diff --git a/python/test/unit/language/test_conversions.py b/python/test/unit/language/test_conversions.py index 25607c3dbafd..14c46000c1ef 100644 --- a/python/test/unit/language/test_conversions.py +++ b/python/test/unit/language/test_conversions.py @@ -283,12 +283,10 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device): launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) return - if src_dtype in ('float8e4b8', 'float8e4b15') and is_cpu(): - pytest.skip(f"Conversion from {src_dtype} to {dst_dtype} is not supported on CPU") - if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9)) or (src_dtype in ('float8e4b15') and is_hip()) - or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or not is_hip_mi300()))): + or (src_dtype in ('float8e4b8', 'float8e5b16') and (is_cuda() or (is_hip() and not is_hip_mi300()))) + or (src_dtype in ('float8e4b8', 'float8e4b15') and is_cpu())): # If the dtype should error out in the given device, we assert that and return with pytest.raises(triton.CompilationError, match="not supported in this architecture"): launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device) @@ -333,6 +331,7 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device): ]) def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device): if is_cpu() and dst_dtype not in ['float8e5', 'float8e4nv', 'float8e5b16']: + # TODO check if 'float8e4b15' downcast is fine for cpu if it will enable in this test pytest.skip(f"Conversion from {src_dtype} to {dst_dtype} is not supported on CPU") if is_cuda(): diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e75aac9f0fc3..3caf302fc83d 100644 --- a/python/test/unit/language/test_core.py +++ 
b/python/test/unit/language/test_core.py @@ -1234,6 +1234,8 @@ def test_abs_fp8(in_dtype, device): pytest.skip("float8e4b15 not supported on CUDA >= 9.0") if in_dtype == tl.float8e4nv and cc < (8, 9): pytest.skip("float8e4nv not supported on CUDA < 8.9") + elif is_cpu(): + pytest.skip('CPU not supports "fp8e4b15"') @triton.jit def abs_kernel(X, Z, SIZE: tl.constexpr): diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 3cb957f33207..9b3ed83674f6 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -33,7 +33,7 @@ class CPUOptions: cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False - supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15", "fp8e4nv") + supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e5b16", "fp8e4nv") deprecated_fp8_dtypes: Tuple[str] = () allowed_dot_input_precisions: Tuple[str] = ("ieee", "tf32", "tf32x3") allow_fp8e4nv: bool = True diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp index 66410855960a..4a4c8bd8e448 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -222,15 +222,13 @@ Value convertToFp8(Location loc, Value src, Type dstFpTy, int dstExpBits, Value convertFp16ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { - // TODO: Fix type to Float8E4M3FN. - return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, false, + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, false, false, rewriter); } Value convertFp16ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { - // TODO: Fix type to Float8E4M3FN. - return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, true, + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, true, false, rewriter); } @@ -287,19 +285,17 @@ Value convertFp16ToFp8E5M2B16Rtne(Location loc, Value src, Value convertBf16ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { - // TODO: Fix type to Float8E4M3FN. Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNUZType(), 4, 7, - false, false, rewriter); + return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNType(), 4, 7, false, + false, rewriter); } Value convertBf16ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { - // TODO: Fix type to Float8E4M3FN. Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNUZType(), 4, 7, true, + return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNType(), 4, 7, true, false, rewriter); } @@ -337,15 +333,13 @@ Value convertBf16ToFp8E5M2B16Rtne(Location loc, Value src, Value convertFp32ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { - // TODO: Fix type to Float8E4M3FN. - return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, false, + return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, false, false, rewriter); } Value convertFp32ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { - // TODO: Fix type to Float8E4M3FN. 
-  return convertToFp8(loc, src, rewriter.getFloat8E4M3FNUZType(), 4, 7, true,
+  return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, true,
                       false, rewriter);
 }

From 539e6be229cbf37ce86168d5391a94dc1d0c64f6 Mon Sep 17 00:00:00 2001
From: Junyi Mei
Date: Sat, 28 Sep 2024 01:12:28 +0800
Subject: [PATCH 119/165] Use 1-D vector reduction op to convert reduce op (#152)

Add a `useReductionOp` option for reduce op lowering.
`mapToMultiDimReductionOp` is renamed to `mapToReductionOp`. When the src
is a 1-D vector, `triton::ReduceOp` is replaced with `vector::ReductionOp`.

Currently, `triton::ReduceOp` is lowered in a generic way with shuffle and
accumulate. However, some architectures support dedicated reduction
instructions (e.g. `vredsum` in RVV), which cannot be instruction-selected
from the current shuffle + accumulate pattern, so sub-optimal code might be
generated. The 1-D reduction can be converted later for a specific target
in the Target TTCIR stage, or simply left to LLVM for code generation.

An alternate approach is keeping `mapToMultiDimReductionOp` enabled by
default and letting the multi-dim reduction be converted to
`llvm.vector.reduce.*` intrinsics. But according to the previous PR, the
current code generation can lead to better results for 2-D reductions.

Signed-off-by: Junyi Mei
---
 third_party/cpu/backend/compiler.py | 2 +-
 .../cpu/include/TritonToTritonCPU/Passes.h | 2 +-
 .../cpu/include/TritonToTritonCPU/Passes.td | 4 ++
 .../TritonToTritonCPU/ConvertReductionOp.cpp | 43 +++++++++++++------
 third_party/cpu/triton_cpu.cc | 5 ++-
 5 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py
index 9b3ed83674f6..1051156d9747 100644
--- a/third_party/cpu/backend/compiler.py
+++ b/third_party/cpu/backend/compiler.py
@@ -128,7 +128,7 @@ def make_ttcir(mod, metadata, opt):
     cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm)
     cpu.passes.ttcpuir.add_convert_dot_op(pm)
     cpu.passes.ttcpuir.add_convert_histogram_op(pm)
-    cpu.passes.ttcpuir.add_convert_reduction_op(pm, False)
+    cpu.passes.ttcpuir.add_convert_reduction_op(pm, True, False)
     cpu.passes.ttcpuir.add_convert_scan_op(pm)
     cpu.passes.ttcpuir.add_convert_cf_ops(pm)
     cpu.passes.ttcpuir.add_convert_atomic_ops(pm)
diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h
index 147aa28da18e..139081781fc5 100644
--- a/third_party/cpu/include/TritonToTritonCPU/Passes.h
+++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h
@@ -29,7 +29,7 @@ std::unique_ptr> createConvertControlFlowOps();
 std::unique_ptr> createConvertHistogramOp();
 std::unique_ptr> createConvertReductionOp();
 std::unique_ptr>
-createConvertReductionOp(bool useMultiDimReductionOp);
+createConvertReductionOp(bool useReductionOp, bool useMultiDimReductionOp);
 std::unique_ptr> createConvertScanOp();
 std::unique_ptr> createConvertAtomicOps();
 std::unique_ptr> createConvertDebugOps();
diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td
index 3e5e2cd89ad3..27d38eb57eac 100644
--- a/third_party/cpu/include/TritonToTritonCPU/Passes.td
+++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td
@@ -117,6 +117,10 @@ def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp">
     Option<"useMultiDimReductionOp", "use-multidim-reduction-op",
            "bool", /*default*/"false",
            "Use vector::MultiDimReductionOp and its default lowering when possible.">,
+
+    Option<"useReductionOp",
"use-reduction-op", + "bool", /*default*/"false", + "Use vector::ReductionOp and its default lowering when possible.">, ]; let dependentDialects = ["mlir::arith::ArithDialect", diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index 68d65ffb3ae5..6f3bfc8b2a01 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -44,9 +44,11 @@ class ReductionConversionTarget : public ConversionTarget { struct ReduceOpConversion : public ReduceScanOpConversionBase { - ReduceOpConversion(bool useMultiDimReductionOp, + ReduceOpConversion(bool useReductionOp, bool useMultiDimReductionOp, const TypeConverter &typeConverter, MLIRContext *context) : ReduceScanOpConversionBase(typeConverter, context) { + + this->useReductionOp = useReductionOp; this->useMultiDimReductionOp = useMultiDimReductionOp; } @@ -56,8 +58,8 @@ struct ReduceOpConversion // More simple cases with a single input and a single combine operation // can be mapped to a vector::MultiDimReductionOp. The resulting code // depends on a quality of LLVM backend and is not always perfect though. - if (useMultiDimReductionOp && - succeeded(mapToMultiDimReductionOp(op, rewriter))) + if (succeeded(mapToReductionOp(op, rewriter, useReductionOp, + useMultiDimReductionOp))) return success(); return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter); @@ -116,9 +118,10 @@ struct ReduceOpConversion return res; } - LogicalResult - mapToMultiDimReductionOp(triton::ReduceOp op, - ConversionPatternRewriter &rewriter) const { + LogicalResult mapToReductionOp(triton::ReduceOp op, + ConversionPatternRewriter &rewriter, + bool useReductionOp, + bool useMultiDimReductionOp) const { if (op.getNumOperands() != 1 || op.getNumResults() != 1) return failure(); @@ -156,9 +159,18 @@ struct ReduceOpConversion Type resTy = getTypeConverter()->convertType(op.getType(0)); Value acc = buildInitValue(op.getLoc(), resTy, reductionKind, rewriter); int64_t axis = op.getAxis(); - rewriter.replaceOpWithNewOp( - op, resTy, reductionKind, src, acc, axis); - return success(); + + if (useReductionOp && srcTy.getShape().size() == 1) { + rewriter.replaceOpWithNewOp(op, resTy, reductionKind, + src, acc); + return success(); + } else if (useMultiDimReductionOp) { + rewriter.replaceOpWithNewOp( + op, resTy, reductionKind, src, acc, axis); + return success(); + } + + return failure(); } LogicalResult detectReductionKind(Operation *op, @@ -258,13 +270,15 @@ struct ReduceOpConversion private: bool useMultiDimReductionOp; + bool useReductionOp; }; struct ConvertReductionOp : public triton::cpu::impl::ConvertReductionOpBase { ConvertReductionOp() = default; - ConvertReductionOp(bool useMultiDimReductionOp) { + ConvertReductionOp(bool useReductionOp, bool useMultiDimReductionOp) { + this->useReductionOp = useReductionOp; this->useMultiDimReductionOp = useMultiDimReductionOp; } @@ -275,8 +289,8 @@ struct ConvertReductionOp TritonToTritonCPUTypeConverter typeConverter; ReductionConversionTarget convTarget(*context, typeConverter); RewritePatternSet patterns(context); - patterns.add(useMultiDimReductionOp, typeConverter, - context); + patterns.add(useReductionOp, useMultiDimReductionOp, + typeConverter, context); if (failed(applyPartialConversionNoBuildMaterializations( mod, convTarget, std::move(patterns)))) @@ -295,8 +309,9 @@ std::unique_ptr> createConvertReductionOp() { } std::unique_ptr> 
-createConvertReductionOp(bool useMultiDimReductionOp) {
-  return std::make_unique(useMultiDimReductionOp);
+createConvertReductionOp(bool useReductionOp, bool useMultiDimReductionOp) {
+  return std::make_unique(useReductionOp,
+                          useMultiDimReductionOp);
 }

 } // namespace cpu
diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc
index 0ab6dba9b7bc..065d740bca79 100644
--- a/third_party/cpu/triton_cpu.cc
+++ b/third_party/cpu/triton_cpu.cc
@@ -47,9 +47,10 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
     pm.addPass(mlir::triton::cpu::createConvertHistogramOp());
   });
   m.def("add_convert_reduction_op",
-        [](mlir::PassManager &pm, bool use_multidim_reduction_op) {
+        [](mlir::PassManager &pm, bool use_reduction_op,
+           bool use_multidim_reduction_op) {
           pm.addPass(mlir::triton::cpu::createConvertReductionOp(
-              use_multidim_reduction_op));
+              use_reduction_op, use_multidim_reduction_op));
         });
   m.def("add_convert_scan_op", [](mlir::PassManager &pm) {
     pm.addPass(mlir::triton::cpu::createConvertScanOp());

From e4588ea02a356b4ffd5fffb27c96e3c724f3ae9f Mon Sep 17 00:00:00 2001
From: Dmitrii Makarenko
Date: Fri, 27 Sep 2024 19:50:46 +0200
Subject: [PATCH 120/165] [Keep materialization] Turn on materialization (#154)

This commit modifies the type conversion to allow a conversion cast from
vector<> to tensor<> and avoids using the option that skips
materialization.

Signed-off-by: Dmitrii Makarenko
---
 .../cpu/include/TritonToTritonCPU/Passes.h | 8 --------
 .../lib/TritonToTritonCPU/ConvertAtomicOps.cpp | 3 +--
 .../TritonToTritonCPU/ConvertControlFlowOps.cpp | 9 +++------
 .../lib/TritonToTritonCPU/ConvertDebugOps.cpp | 3 +--
 .../cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp | 3 +--
 .../TritonToTritonCPU/ConvertElemManipOps.cpp | 3 +--
 .../TritonToTritonCPU/ConvertElementwiseOps.cpp | 3 +--
 .../TritonToTritonCPU/ConvertHistogramOp.cpp | 3 +--
 .../lib/TritonToTritonCPU/ConvertMemoryOps.cpp | 3 +--
 .../cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp | 3 +--
 .../TritonToTritonCPU/ConvertReductionOp.cpp | 3 +--
 .../cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp | 3 +--
 .../cpu/lib/TritonToTritonCPU/TypeConverter.cpp | 17 +++++++++++++++++
 13 files changed, 30 insertions(+), 34 deletions(-)

diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h
index 139081781fc5..055e433371d9 100644
--- a/third_party/cpu/include/TritonToTritonCPU/Passes.h
+++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h
@@ -37,14 +37,6 @@ std::unique_ptr> createConvertDebugOps();
 #define GEN_PASS_REGISTRATION
 #include "cpu/include/TritonToTritonCPU/Passes.h.inc"

-inline LogicalResult applyPartialConversionNoBuildMaterializations(
-    Operation *op, const ConversionTarget &target,
-    const FrozenRewritePatternSet &patterns,
-    ConversionConfig config = ConversionConfig()) {
-  config.buildMaterializations = false;
-  return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
-}
-
 } // namespace cpu
 } // namespace triton

diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp
index 281dd8474b30..bab0cd94c57e 100644
--- a/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp
+++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertAtomicOps.cpp
@@ -192,8 +192,7 @@ struct ConvertAtomicOps
     patterns.add(typeConverter, context);
     patterns.add(typeConverter, context);

-    if (failed(applyPartialConversionNoBuildMaterializations(
-            mod, convTarget, std::move(patterns))))
+    if
(failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp index fddbf3a50583..491b647103a7 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertControlFlowOps.cpp @@ -168,8 +168,7 @@ struct ConvertControlFlowOps RewritePatternSet patterns(context); patterns.add>(typeConverter, context); patterns.add>(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } @@ -180,8 +179,7 @@ struct ConvertControlFlowOps { RewritePatternSet patterns(context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } @@ -197,8 +195,7 @@ struct ConvertControlFlowOps RewritePatternSet patterns(context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp index c60ef5a2ed2c..8a83156e4c52 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -109,8 +109,7 @@ struct ConvertDebugOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp index 102ad1f81448..06cfb0d834d0 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -110,8 +110,7 @@ struct ConvertDotOp : public triton::impl::ConvertDotOpBase { RewritePatternSet patterns(context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp index 31d80f115e75..a39a93e42446 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -229,8 +229,7 @@ struct ConvertElemManipOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp 
b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index b0142d69ea16..4c37524e1b5f 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -253,8 +253,7 @@ struct ConvertElementwiseOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp index f5e18763f168..0bcbfcc9f264 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp @@ -114,8 +114,7 @@ struct ConvertHistogramOp RewritePatternSet patterns(context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 814552f1691b..02a458986269 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -914,8 +914,7 @@ struct ConvertMemoryOps patterns.add(axisInfoAnalysis, shapeInfoAnalysis, pointerConverter, useScalarLoops, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp index 3994723e4a26..27f49a3078c1 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertPtrOps.cpp @@ -177,8 +177,7 @@ struct ConvertPtrOps : public triton::impl::ConvertPtrOpsBase { patterns.add(typeConverter, context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index 6f3bfc8b2a01..6f3f8112ca43 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -292,8 +292,7 @@ struct ConvertReductionOp patterns.add(useReductionOp, useMultiDimReductionOp, typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp index 0697b4540252..fef15b046621 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp @@ -136,8 +136,7 @@ struct ConvertScanOp : public 
triton::impl::ConvertScanOpBase { RewritePatternSet patterns(context); patterns.add(typeConverter, context); - if (failed(applyPartialConversionNoBuildMaterializations( - mod, convTarget, std::move(patterns)))) + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp index ce66f8faeb3e..1b078f20020e 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp @@ -15,6 +15,19 @@ TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { return VectorType::get(tensorTy.getShape(), elemTy); }); + addArgumentMaterialization([&](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> std::optional { + if (isa(type)) + return builder.create(loc, type, inputs) + .getResult(0); + llvm::errs() << "Inputs: "; + llvm::interleaveComma(inputs, llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "Type: " << type << "\n"; + llvm_unreachable("Unexpected argument materizalization"); + }); + // Converted ops produce vectors instead of tensors. Provide conversion // here for users. addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, @@ -29,6 +42,10 @@ TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { if (isa(type)) return builder.create(loc, type, inputs) .getResult(0); + llvm::errs() << "Inputs: "; + llvm::interleaveComma(inputs, llvm::errs()); + llvm::errs() << "\n"; + llvm::errs() << "Type: " << type << "\n"; llvm_unreachable("Unexpected target materizalization"); }); } From 1ef9001060610e89c070f0b4b101c9b63d1c0175 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Mon, 30 Sep 2024 18:37:02 +0200 Subject: [PATCH 121/165] [Scalarization/Loops generation] Refactor and new pass/interfaces introduced (#123) This commit reworks functionality introduced in PR #119. The current approach decides to insert SCF cycles earlier - with TTC IR and with tensors instead of vectors. This allows the analysis to be separated by operation type and applied via an external interface, avoiding changes to the original operations. (ScalarizeInterface) The original scalarization logic is now separated and placed in ScalarizeUsingForOpPass. This reduces the complexity of the ConvertMemoryOps pass. Thus, the conversion is now separated from the transformation of TTC IR. 
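For reference, the sketch below (not part of this patch) shows how a rewrite
pattern is expected to use the free functions declared in
ScalarizePass/ScalarizeInterface.h further down in this series. The helper
name `rebuildElement` is hypothetical; only `canComputeScalarValue` and
`computeScalarValue` come from the patch itself.

    #include "cpu/include/ScalarizePass/ScalarizeInterface.h"
    #include "mlir/IR/PatternMatch.h"

    using namespace mlir;

    // Rebuild a single element of a tensor value inside an scf.for body
    // instead of materializing the whole tensor.
    static Value rebuildElement(Value tensorVal, ValueRange ivs,
                                PatternRewriter &rewriter) {
      // Succeeds only if every producer (splat, make_range, broadcast,
      // expand_dims, trans, splat constant, ...) has the interface attached.
      if (!triton::cpu::canComputeScalarValue(tensorVal))
        return Value();
      // The ValueRange overload is the loop variant: the indices are the
      // induction variables of the generated loop nest.
      return triton::cpu::computeScalarValue(tensorVal.getDefiningOp(),
                                             tensorVal, ivs, rewriter);
    }

Values that cannot be rebuilt this way are spilled to a temporary memref by
the new pass and read back element by element inside the loop.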
--- bin/RegisterTritonDialects.h | 2 + .../Dialect/TritonCPU/IR/TritonCPUOps.td | 37 +- test/TritonCPU/convert-memory-ops.mlir | 2 +- test/TritonCPU/scalarize-memory-ops.mlir | 20 +- third_party/cpu/backend/compiler.py | 3 +- third_party/cpu/include/CMakeLists.txt | 1 + .../cpu/include/ScalarizePass/CMakeLists.txt | 1 + .../ScalarizePass/ScalarizeInterface.h | 33 + .../ScalarizePass/ScalarizeInterface.td | 52 ++ .../ScalarizePass/ScalarizeInterfaceImpl.h | 16 + .../cpu/include/TritonToTritonCPU/Passes.h | 51 +- .../cpu/include/TritonToTritonCPU/Passes.td | 22 +- .../cpu/lib/TritonToTritonCPU/CMakeLists.txt | 4 + .../TritonToTritonCPU/ConvertMemoryOps.cpp | 643 ++++-------------- .../TritonToTritonCPU/ScalarizeInterface.cpp | 277 ++++++++ .../ScalarizeUsingForOps.cpp | 387 +++++++++++ third_party/cpu/triton_cpu.cc | 12 +- 17 files changed, 1019 insertions(+), 544 deletions(-) create mode 100644 third_party/cpu/include/ScalarizePass/CMakeLists.txt create mode 100644 third_party/cpu/include/ScalarizePass/ScalarizeInterface.h create mode 100644 third_party/cpu/include/ScalarizePass/ScalarizeInterface.td create mode 100644 third_party/cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ScalarizeInterface.cpp create mode 100644 third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 41dc478fd7ce..dee36aa7fa92 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -17,6 +17,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h" #include "cpu/include/TritonCPUToLLVM/Passes.h" #include "cpu/include/TritonCPUTransforms/Passes.h" #include "cpu/include/TritonToTritonCPU/Passes.h" @@ -76,6 +77,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::cpu::registerTritonToTritonCPUPasses(); mlir::triton::cpu::registerTritonCPUTransformsPasses(); mlir::triton::cpu::registerTritonCPUToLLVMPasses(); + mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); // TODO: register Triton & TritonGPU passes registry.insert { def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">; + +def TTC_LoadOp : TTC_Op<"load", [ + MemoryEffects<[MemRead]>, +]> { + let summary = "Load from a memref to triton tensor"; + + let description = [{ + Operation to allow load from allocated temporary buffer to triton tensor. + }]; + + let arguments = (ins AnyMemRef:$src); + + let results = (outs TT_Type:$result); + + let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)"; +} + +def TTC_StoreOp : TTC_Op<"store", [ + MemoryEffects<[MemWrite]>, +]> { + let summary = "Store triton tensor to memref"; + + let description = [{ + Operation to allow store triton tensor to allocated temporary buffer. 
+ }]; + + let arguments = ( + ins + TT_Type:$src, + AnyMemRef:$dst + ); + + let assemblyFormat = "$src `,` $dst attr-dict `:` type($src) `,` type($dst)"; +} + def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { let summary = "Print at most a single scalar or vector (converted from tensor) on each line"; @@ -100,7 +135,7 @@ def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { let hasVerifier = 1; } -def TT_AssertOp : TTC_Op<"assert", [MemoryEffects<[MemWrite]>]> { +def TTC_AssertOp : TTC_Op<"assert", [MemoryEffects<[MemWrite]>]> { let summary = "For correctness checking"; let description = [{ Takes a condition tensor, a message string, a file string, a function string, and a line number. diff --git a/test/TritonCPU/convert-memory-ops.mlir b/test/TritonCPU/convert-memory-ops.mlir index 32f8630cab84..c98747269fdc 100644 --- a/test/TritonCPU/convert-memory-ops.mlir +++ b/test/TritonCPU/convert-memory-ops.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops=use-scalar-loops=false | FileCheck %s +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops | FileCheck %s // Convert strided masked loads to scalar loads. diff --git a/test/TritonCPU/scalarize-memory-ops.mlir b/test/TritonCPU/scalarize-memory-ops.mlir index f62bbb5765f7..f1934d9ffc14 100644 --- a/test/TritonCPU/scalarize-memory-ops.mlir +++ b/test/TritonCPU/scalarize-memory-ops.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops=use-scalar-loops=true -cse -canonicalize | FileCheck %s +// RUN: triton-opt %s -split-input-file -triton-cpu-scalarize -cse -canonicalize | FileCheck %s // Convert strided masked load and store to loops. Pointer and mask should be scalarized. // TODO: There is an optimization opportunity to fuse loops. 
@@ -18,9 +18,9 @@ // CHECK-NEXT: memref.store %{{.*}}, %[[ALLOCA1]][%[[IV1]]] : memref<128xf32> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: %[[VEC_VAL:.*]] = vector.transfer_read %[[ALLOCA1]][%c0], %{{.*}} {in_bounds = [true]} : memref<128xf32>, vector<128xf32> +// CHECK-NEXT: %[[TENSOR_VAL:.*]] = triton_cpu.load %[[ALLOCA1]] : memref<128xf32> -> tensor<128xf32> // CHECK-NEXT: %[[ALLOCA2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> -// CHECK-NEXT: vector.transfer_write %[[VEC_VAL]], %[[ALLOCA2]][%c0] {in_bounds = [true]} : vector<128xf32>, memref<128xf32> +// CHECK-NEXT: triton_cpu.store %[[TENSOR_VAL]], %[[ALLOCA2]] : tensor<128xf32>, memref<128xf32> // CHECK-NEXT: scf.for %[[IV2:.*]] = %c0 to %c128 step %c1 { // CHECK-NEXT: %[[IV2_I32:.*]] = arith.index_castui %[[IV2]] : index to i32 // CHECK-NEXT: %[[IDX2:.*]] = arith.muli %[[IV2_I32]], %c3_i32 : i32 @@ -59,10 +59,10 @@ module { // CHECK-LABEL: @indirect_masked_load_store // CHECK: %[[ALLOCA_VALS1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> // CHECK-NEXT: %[[ALLOCA_PTRS1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi64> -// CHECK-NEXT: vector.transfer_write %{{.*}}, %[[ALLOCA_PTRS1]][%c0] {in_bounds = [true]} : vector<128xi64>, memref<128xi64> -// CHECK-NEXT: %[[EXT_MASK:.*]] = arith.extui %{{.*}} : vector<128xi1> to vector<128xi8> +// CHECK-NEXT: triton_cpu.store %{{.*}}, %[[ALLOCA_PTRS1]] : tensor<128x!tt.ptr>, memref<128xi64> +// CHECK-NEXT: %[[EXT_MASK:.*]] = arith.extui %{{.*}} : tensor<128xi1> to tensor<128xi8> // CHECK-NEXT: %[[ALLOCA_MASK1:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi8> -// CHECK-NEXT: vector.transfer_write %[[EXT_MASK]], %[[ALLOCA_MASK1]][%c0] {in_bounds = [true]} : vector<128xi8>, memref<128xi8> +// CHECK-NEXT: triton_cpu.store %[[EXT_MASK]], %[[ALLOCA_MASK1]] : tensor<128xi8>, memref<128xi8> // CHECK-NEXT: scf.for %[[IV1:.*]] = %c0 to %c128 step %c1 { // CHECK-NEXT: %[[PTR1_INT:.*]] = memref.load %[[ALLOCA_PTRS1]][%[[IV1]]] : memref<128xi64> // CHECK-NEXT: %[[PTR1:.*]] = tt.int_to_ptr %[[PTR1_INT]] : i64 -> !tt.ptr @@ -75,13 +75,13 @@ module { // CHECK-NEXT: memref.store %{{.*}}, %[[ALLOCA_VALS1]][%[[IV1]]] : memref<128xf32> // CHECK-NEXT: } // CHECK-NEXT: } -// CHECK-NEXT: %[[VEC_VAL:.*]] = vector.transfer_read %[[ALLOCA_VALS1]][%c0], %{{.*}} {in_bounds = [true]} : memref<128xf32>, vector<128xf32> +// CHECK-NEXT: %[[TENSOR_VAL:.*]] = triton_cpu.load %[[ALLOCA_VALS1]] : memref<128xf32> -> tensor<128xf32> // CHECK: %[[ALLOCA_PTRS2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi64> -// CHECK-NEXT: vector.transfer_write %{{.*}}, %[[ALLOCA_PTRS2]][%c0] {in_bounds = [true]} : vector<128xi64>, memref<128xi64> +// CHECK-NEXT: triton_cpu.store %{{.*}}, %[[ALLOCA_PTRS2]] : tensor<128x!tt.ptr>, memref<128xi64> // CHECK-NEXT: %[[ALLOCA_MASK2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xi8> -// CHECK-NEXT: vector.transfer_write %[[EXT_MASK]], %[[ALLOCA_MASK2]][%c0] {in_bounds = [true]} : vector<128xi8>, memref<128xi8> +// CHECK-NEXT: triton_cpu.store %[[EXT_MASK]], %[[ALLOCA_MASK2]] : tensor<128xi8>, memref<128xi8> // CHECK-NEXT: %[[ALLOCA_VALS2:.*]] = memref.alloca() {alignment = 64 : i64} : memref<128xf32> -// CHECK-NEXT: vector.transfer_write %[[VEC_VAL]], %[[ALLOCA_VALS2]][%c0] {in_bounds = [true]} : vector<128xf32>, memref<128xf32> +// CHECK-NEXT: triton_cpu.store %[[TENSOR_VAL]], %[[ALLOCA_VALS2]] : tensor<128xf32>, memref<128xf32> // CHECK-NEXT: scf.for %[[IV2:.*]] = %c0 to %c128 step %c1 { // CHECK-NEXT: 
%[[PTR2_INT:.*]] = memref.load %[[ALLOCA_PTRS2]][%[[IV2]]] : memref<128xi64> // CHECK-NEXT: %[[PTR2:.*]] = tt.int_to_ptr %[[PTR1_INT]] : i64 -> !tt.ptr diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 1051156d9747..991d694e854c 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -122,7 +122,8 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - cpu.passes.ttcpuir.add_convert_memory_ops(pm, True) + cpu.passes.ttcpuir.add_scalarize(pm) + cpu.passes.ttcpuir.add_convert_memory_ops(pm) cpu.passes.ttcpuir.add_convert_ptr_ops(pm) cpu.passes.ttcpuir.add_convert_elementwise_ops(pm) cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm) diff --git a/third_party/cpu/include/CMakeLists.txt b/third_party/cpu/include/CMakeLists.txt index b4c91e794072..30282a4736c1 100644 --- a/third_party/cpu/include/CMakeLists.txt +++ b/third_party/cpu/include/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(ScalarizePass) add_subdirectory(TritonCPUToLLVM) add_subdirectory(TritonCPUTransforms) add_subdirectory(TritonToTritonCPU) diff --git a/third_party/cpu/include/ScalarizePass/CMakeLists.txt b/third_party/cpu/include/ScalarizePass/CMakeLists.txt new file mode 100644 index 000000000000..f03fb94e2b1a --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_interface(ScalarizeInterface) diff --git a/third_party/cpu/include/ScalarizePass/ScalarizeInterface.h b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.h new file mode 100644 index 000000000000..1b16ff935540 --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.h @@ -0,0 +1,33 @@ +#ifndef MLIR_INTERFACES_SCALARIZE_INTERFACE_H_ +#define MLIR_INTERFACES_SCALARIZE_INTERFACE_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/LLVM.h" + +#include "mlir/IR/OpDefinition.h" + +/// Include the ODS generated interface header files. +#include "cpu/include/ScalarizePass/ScalarizeInterface.h.inc" + +namespace mlir { +namespace triton { +namespace cpu { + +mlir::Value computeScalarValue(mlir::Operation *scalarizationOp, + mlir::Value vals, + mlir::ArrayRef indices, + mlir::PatternRewriter &rewriter); + +mlir::Value computeScalarValue(mlir::Operation *scalarizationOp, + mlir::Value vals, mlir::ValueRange indices, + mlir::PatternRewriter &rewriter); + +bool canComputeScalarValue(mlir::Value vals); +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif // MLIR_INTERFACES_SCALARIZE_INTERFACE_H_ diff --git a/third_party/cpu/include/ScalarizePass/ScalarizeInterface.td b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.td new file mode 100644 index 000000000000..7e6c4acecbcb --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/ScalarizeInterface.td @@ -0,0 +1,52 @@ +#ifndef MLIR_SCALARIZEINTERFACE +#define MLIR_SCALARIZEINTERFACE + +include "mlir/IR/OpBase.td" + +def ScalarizeInterface : OpInterface<"ScalarizeInterface"> { + let description = [{ + Interface for allowing operations to expose information needed to + scalarize them or in simpler terms inserts SCF loops to reduce amount of + generated ir. Similar with checking operands of specific operations for + constancy - to understand is it possible to put it inside of loop's body. 
+ }]; + let cppNamespace = "mlir::triton::cpu"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Checks operand and is ScalarizeInterface registered for this operation. + }], + /*retType=*/"bool", + /*methodName=*/"canComputeScalarValue", + /*args=*/(ins + "mlir::Value ":$vals) + >, + InterfaceMethod< + /*desc=*/[{ + Returns value that can be put inside of generated cycle and creates required constants. + Can go throught operands to check type of passed values. Implementation for static indeces. + }], + /*retType=*/"mlir::Value", + /*methodName=*/"computeScalarValue", + /*args=*/(ins + "mlir::Value ":$vals, + "mlir::ArrayRef ":$indices, + "mlir::PatternRewriter &":$rewriter) + >, + InterfaceMethod< + /*desc=*/[{ + Returns value that can be put inside of generated cycle and creates required constants. + Can go throught operands to check type of passed values. Implementation for dynamic indices + which is in common used in loops to iterate with Inductional Variable. + }], + /*retType=*/"mlir::Value", + /*methodName=*/"computeScalarValueForLoop", + /*args=*/(ins + "mlir::Value ":$vals, + "mlir::ValueRange ":$indices, + "mlir::PatternRewriter &":$rewriter) + > + ]; +} + +#endif // MLIR_SCALARIZEINTERFACE diff --git a/third_party/cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h b/third_party/cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h new file mode 100644 index 000000000000..ab2730d9acf8 --- /dev/null +++ b/third_party/cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h @@ -0,0 +1,16 @@ +#ifndef MLIR_DIALECT_TRITON_SCALARIZEINTERFACEIMPL_H +#define MLIR_DIALECT_TRITON_SCALARIZEINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace triton { +namespace cpu { + +void registerTritonOpScalarizeExternalModels(DialectRegistry ®istry); + +} // namespace cpu +} // namespace triton +} // namespace mlir + +#endif // MLIR_DIALECT_TRITON_SCALARIZEINTERFACEIMPL_H diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 055e433371d9..a84b69c4b754 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -4,6 +4,8 @@ #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "triton/Analysis/AxisInfo.h" +#include "llvm/ADT/TypeSwitch.h" #include @@ -21,8 +23,6 @@ namespace cpu { std::unique_ptr> createConvertElementwiseOps(); std::unique_ptr> createConvertElemManipOps(); std::unique_ptr> createConvertMemoryOps(); -std::unique_ptr> -createConvertMemoryOps(bool useScalarLoops); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); @@ -34,9 +34,56 @@ std::unique_ptr> createConvertScanOp(); std::unique_ptr> createConvertAtomicOps(); std::unique_ptr> createConvertDebugOps(); +std::unique_ptr> createScalarizeUsingForOpPass(); + #define GEN_PASS_REGISTRATION #include "cpu/include/TritonToTritonCPU/Passes.h.inc" +template +constexpr bool is_one_of_v = (std::is_same_v || ...); + +template +constexpr bool is_memory_op_v = + is_one_of_v; + +inline mlir::Type getMemoryOpType(triton::LoadOp operation) { + return operation.getType(); +} + +inline mlir::Type getMemoryOpType(triton::StoreOp operation) { + return operation.getValue().getType(); +} + +inline ArrayRef getShape(mlir::Type type) { + return llvm::TypeSwitch>(type) + .Case([](ShapedType t) { return t.getShape(); }) + .Case([](RankedTensorType t) { return 
t.getShape(); }) + .Default([](Type t) { + llvm::errs() << "Attempt to getShape from unknow type: " << t << "\n"; + llvm_unreachable("Unsupported type in getShape"); + return ArrayRef(); + }); +} + +inline bool hasShape(mlir::Type type) { + return isa(type); +} + +template , bool> = true> +bool isContiguousRowMajorAccess(AxisInfo *axisInfo, OpTy op) { + if (!axisInfo) + return false; + + mlir::Type type = getMemoryOpType(op); + if (!hasShape(type)) { + return false; + } + auto shape = getShape(type); + auto contiguity = axisInfo->getContiguity(); + return (shape.back() > 1 && shape.back() == contiguity.back()); +} + } // namespace cpu } // namespace triton diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 27d38eb57eac..230731249783 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -10,12 +10,6 @@ def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { }]; let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; - let options = [ - Option<"useScalarLoops", "use-scalar-loops", - "bool", /*default*/"true", - "Enable lowering of tensor loads and stores to scalar loops.">, - ]; - let dependentDialects = ["mlir::arith::ArithDialect", "mlir::memref::MemRefDialect", "mlir::vector::VectorDialect", @@ -171,4 +165,20 @@ def ConvertDebugOps : Pass<"triton-cpu-convert-debug-ops", "mlir::ModuleOp"> { "mlir::triton::cpu::TritonCPUDialect"]; } +def ScalarizeUsingForOp : Pass<"triton-cpu-scalarize", "mlir::ModuleOp"> { + let summary = "Insert Loops for ops, that are not vectorizable"; + let description = [{ + This pass is used to reduce compile time by generating loops for + operations that cannot be handled as vectors, and simply increases + the amount of IR without any further optimization. 
+ }]; + + let constructor = "mlir::triton::cpu::createScalarizeUsingForOpPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + + #endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index 18e675044881..47023eff75be 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -6,6 +6,8 @@ add_triton_library(TritonToTritonCPU ConvertElementwiseOps.cpp ConvertElemManipOps.cpp ConvertHistogramOp.cpp + ScalarizeInterface.cpp + ScalarizeUsingForOps.cpp ConvertMemoryOps.cpp ConvertPtrOps.cpp ConvertReductionOp.cpp @@ -14,6 +16,8 @@ add_triton_library(TritonToTritonCPU DEPENDS TritonToTritonCPUPassIncGen + MLIRScalarizeInterfaceIncGen + MLIRDialectUtils LINK_LIBS PUBLIC TritonCPUIR diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 02a458986269..2ca6becadcfe 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -20,6 +20,8 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include "cpu/include/ScalarizePass/ScalarizeInterface.h" + namespace mlir { namespace triton { namespace cpu { @@ -43,11 +45,9 @@ struct MemoryOpConversion : public OpConversionPattern { MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, - TypeConverter &typeConverter, bool useScalarLoops, - MLIRContext *context) + TypeConverter &typeConverter, MLIRContext *context) : OpConversionPattern(typeConverter, context), - axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis), - genScalarLoops(useScalarLoops) {} + axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis) {} Value extractScalarPointer(Location loc, Value ptrs, ArrayRef indices, @@ -56,8 +56,9 @@ struct MemoryOpConversion : public OpConversionPattern { // compiler doesn't always optimize it to a simple scalar pointer // computation. Here we try to follow a data flow of the tensor to rebuild a // scalar pointer for more efficient resulting code. - if (canComputeScalarValue(ptrs)) - return computeScalarValue(ptrs, indices, rewriter); + if (canComputeScalarValue(ptrs)) { + return computeScalarValue(ptrs.getDefiningOp(), ptrs, indices, rewriter); + } // Fall back to a scalar pointer extraction from the vector. 
Value ptr = rewriter.create( @@ -67,206 +68,6 @@ struct MemoryOpConversion : public OpConversionPattern { return ptr; } - bool canComputeScalarValue(Value vals) const { - auto def = vals.getDefiningOp(); - if (!def) - return false; - - if (isa(*def)) { - for (auto op : def->getOperands()) { - if (!canComputeScalarValue(op)) - return false; - } - return true; - } - - if (isa(*def)) - return true; - - if (auto cst = dyn_cast(def)) { - if (auto denseVal = dyn_cast(cst.getValue())) { - return denseVal.isSplat(); - } - return false; - } - - return false; - } - - Value computeScalarValue(Value vals, ArrayRef indices, - ConversionPatternRewriter &rewriter) const { - if (auto def = vals.getDefiningOp()) { - return def.getSrc(); - } - - if (auto def = vals.getDefiningOp()) { - int32_t start = static_cast(def.getStart()); - assert(indices.size() == 1); - Type elemTy = cast(def.getType()).getElementType(); - return rewriter.create( - def.getLoc(), elemTy, - rewriter.getIntegerAttr(elemTy, start + indices[0])); - } - - if (auto def = vals.getDefiningOp()) { - // Find broadcasted dimensions and replace indices for those dimensions - // with 0 (broadcasted dimension always has size 1). - SmallVector newIndices; - auto sourceTy = cast(def.getSrc().getType()); - auto targetTy = cast(def.getType()); - assert(sourceTy.getRank() == indices.size() && "Mismatched rank"); - for (int64_t i = 0; i < sourceTy.getRank(); ++i) { - if (sourceTy.getShape()[i] != targetTy.getShape()[i]) - newIndices.push_back(0); - else - newIndices.push_back(indices[i]); - } - return computeScalarValue(def.getSrc(), newIndices, rewriter); - } - - if (auto def = vals.getDefiningOp()) { - // Remove index at expanded dimension. - SmallVector newIndices(indices); - newIndices.erase(newIndices.begin() + def.getAxis()); - return computeScalarValue(def.getSrc(), newIndices, rewriter); - } - - if (auto def = vals.getDefiningOp()) { - auto denseVal = cast(def.getValue()); - assert(denseVal.isSplat()); - auto scalarAttr = denseVal.getSplatValue(); - Value res = rewriter.create( - def.getLoc(), scalarAttr.getType(), scalarAttr); - return res; - } - - if (auto def = vals.getDefiningOp()) { - // Permute indices. - SmallVector newIndices; - auto order = def.getOrder(); - assert(indices.size() == order.size() && "Mismatched rank"); - for (auto idx : order) - newIndices.push_back(indices[idx]); - return computeScalarValue(def.getSrc(), newIndices, rewriter); - } - - // Generic case where we copy defining op with scalar operands. 
- auto def = vals.getDefiningOp(); - OperationState newState(def->getLoc(), def->getName()); - for (auto op : def->getOperands()) { - newState.operands.push_back(computeScalarValue(op, indices, rewriter)); - } - assert(def->getResults().size() == 1); - newState.types.push_back( - cast(def->getResultTypes()[0]).getElementType()); - newState.attributes = def->getAttrs(); - return rewriter.create(newState)->getResult(0); - } - - Value computeScalarValue(Value vals, ValueRange indices, - ConversionPatternRewriter &rewriter, - DenseMap &valMap) const { - if (valMap.count(vals)) - return valMap.at(vals); - - if (auto def = vals.getDefiningOp()) { - return def.getSrc(); - } - - if (auto def = vals.getDefiningOp()) { - auto denseVal = cast(def.getValue()); - assert(denseVal.isSplat()); - auto scalarAttr = denseVal.getSplatValue(); - Value res = rewriter.create( - def.getLoc(), scalarAttr.getType(), scalarAttr); - valMap[vals] = res; - return res; - } - - if (auto def = vals.getDefiningOp()) { - assert(indices.size() == 1); - int32_t start = static_cast(def.getStart()); - Type elemTy = cast(def.getType()).getElementType(); - Value startVal = rewriter.create( - def.getLoc(), elemTy, rewriter.getIntegerAttr(elemTy, start)); - Value index = indices[0]; - if (!elemTy.isIndex()) - index = - rewriter.create(def.getLoc(), elemTy, index); - Value res = - rewriter.create(def.getLoc(), elemTy, startVal, index); - valMap[vals] = res; - return res; - } - - if (auto def = vals.getDefiningOp()) { - // Find broadcasted dimensions and replace indices for those dimensions - // with 0 (broadcasted dimension has always size 1). - SmallVector newIndices; - auto sourceTy = cast(def.getSrc().getType()); - auto targetTy = cast(def.getType()); - assert(sourceTy.getRank() == indices.size() && "Mismatched rank"); - for (int64_t i = 0; i < sourceTy.getRank(); ++i) { - if (sourceTy.getShape()[i] != targetTy.getShape()[i]) - newIndices.push_back( - rewriter.create(def.getLoc(), 0)); - else - newIndices.push_back(indices[i]); - } - // The original cache is only used for the original set of indices. - DenseMap tmpValMap; - Value res = - computeScalarValue(def.getSrc(), newIndices, rewriter, tmpValMap); - valMap[vals] = res; - return res; - } - - if (auto def = vals.getDefiningOp()) { - // Remove index at expanded dimension. - SmallVector newIndices = indices; - newIndices.erase(newIndices.begin() + def.getAxis()); - // The original cache is only used for the original set of indices. - DenseMap tmpValMap; - Value res = - computeScalarValue(def.getSrc(), newIndices, rewriter, tmpValMap); - valMap[vals] = res; - return res; - } - - if (auto def = vals.getDefiningOp()) { - // Permute indices. - SmallVector newIndices; - auto order = def.getOrder(); - assert(indices.size() == order.size() && "Mismatched rank"); - for (auto idx : order) - newIndices.push_back(indices[idx]); - // The original cache is only used for the original set of indices. - DenseMap tmpValMap; - Value res = - computeScalarValue(def.getSrc(), newIndices, rewriter, tmpValMap); - valMap[vals] = res; - return res; - } - - // Generic case where we copy defining op with scalar operands. 
- auto def = vals.getDefiningOp(); - OperationState newState(def->getLoc(), def->getName()); - for (auto op : def->getOperands()) { - newState.operands.push_back( - computeScalarValue(op, indices, rewriter, valMap)); - } - assert(def->getResults().size() == 1); - newState.types.push_back( - cast(def->getResultTypes()[0]).getElementType()); - newState.attributes = def->getAttrs(); - Value res = rewriter.create(newState)->getResult(0); - valMap[vals] = res; - return res; - } - Value extractMemRef(Location loc, Value ptr, ConversionPatternRewriter &rewriter) const { auto tensorTy = dyn_cast( @@ -333,215 +134,9 @@ struct MemoryOpConversion : public OpConversionPattern { return memRef; } - // Load scalar element from a temporary buffer or recompute it if the - // buffer doesn't exist. - Value computeOrLoadScalarValue(Value vals, Value tmpVals, ValueRange indices, - ConversionPatternRewriter &rewriter, - DenseMap &valMap) const { - // Allow null value for easier handling of optional arguments. - if (!vals) - return nullptr; - - // Load value from a temp buffer if any. - if (tmpVals) { - Value val = - rewriter.create(vals.getLoc(), tmpVals, indices); - // If we load a pointer then additional cast is needed because tensor of - // pointers is transformed into a vector of integers. - auto elemTy = dyn_cast(vals.getType()).getElementType(); - if (isa(elemTy)) - val = rewriter.create(vals.getLoc(), elemTy, val); - // We need to transform loaded i8 back to i1. - else if (elemTy.isInteger(1)) - val = rewriter.create(val.getLoc(), - rewriter.getI1Type(), val); - return val; - } - - return computeScalarValue(vals, indices, rewriter, valMap); - } - - LogicalResult scalarizeWithLoop(triton::LoadOp loadOp, - ConversionPatternRewriter &rewriter) const { - auto loc = loadOp.getLoc(); - auto vecTy = - dyn_cast(getTypeConverter()->convertType(loadOp.getType())); - - auto ptrs = loadOp.getPtr(); - auto mask = loadOp.getMask(); - auto other = loadOp.getOther(); - auto cache = loadOp.getCache(); - auto evict = loadOp.getEvict(); - auto isVolatile = loadOp.getIsVolatile(); - - // Create some reused constants. - Value zeroIdx = rewriter.create(loc, 0); - Value oneIdx = rewriter.create(loc, 1); - - // There is alloca_scope operation to control alloca scopes. But its usage - // in combination with nested SCF and multi-dimensional vectors make it - // impossible to lower scopes to LLVM using existing MLIR passes. For now, - // simply allocate temp memory in the function's region. - // TODO: Use alloc for big buffers and revisit alloca scoping. - Operation *allocaPoint = loadOp; - while (!isa(allocaPoint->getParentOp())) - allocaPoint = allocaPoint->getParentOp(); - - // Allocate temp buffer for the result. Write the other value there if - // we cannot write it in a loop. - auto resMemRefTy = - MemRefType::get(vecTy.getShape(), vecTy.getElementType()); - Value resMemRef = createAlloca(loc, resMemRefTy, allocaPoint, rewriter); - bool storeOtherInLoop = static_cast(mask); - if (other && !canComputeScalarValue(other)) { - SmallVector indices(vecTy.getRank(), zeroIdx); - rewriter.create( - loc, rewriter.getRemappedValue(other), resMemRef, indices); - storeOtherInLoop = false; - } - - // Store a tensor of pointers and mask into a temp buf if we can't - // compute them in a loop. - Value tmpPtrs = - maybeStoreVecToTempBuf(loc, ptrs, zeroIdx, allocaPoint, rewriter); - Value tmpMask = - maybeStoreVecToTempBuf(loc, mask, zeroIdx, allocaPoint, rewriter); - - // Create for-loops to iterate through all vector dimensions. 
- SmallVector forOps; - SmallVector ivs; - for (int64_t i = 0; i < vecTy.getRank(); ++i) { - Value upperBound = - rewriter.create(loc, vecTy.getShape()[i]); - auto forOp = - rewriter.create(loc, zeroIdx, upperBound, oneIdx); - forOps.push_back(forOp); - ivs.push_back(forOp.getInductionVar()); - rewriter.setInsertionPointToStart(forOp.getBody()); - } - - // Compute or load a scalar arguments. - DenseMap valMap; - Value scalarPtr = - computeOrLoadScalarValue(ptrs, tmpPtrs, ivs, rewriter, valMap); - Value scalarMask = - computeOrLoadScalarValue(mask, tmpMask, ivs, rewriter, valMap); - Value scalarOther; - if (storeOtherInLoop) { - if (other) { - scalarOther = computeScalarValue(other, ivs, rewriter, valMap); - } else { - scalarOther = rewriter.create( - loc, vecTy.getElementType(), - rewriter.getZeroAttr(vecTy.getElementType())); - } - } - - if (!mask) { - // Regular load case. - Value val = rewriter.create(loc, scalarPtr, cache, evict, - isVolatile); - rewriter.create(loc, val, resMemRef, ivs); - } else { - // Conditional load case - rewriter.create( - loc, scalarMask, - [&](OpBuilder &builder, Location loc) { - Value val = builder.create(loc, scalarPtr, cache, - evict, isVolatile); - builder.create(loc, val, resMemRef, ivs); - builder.create(loc); - }, - [&](OpBuilder &builder, Location loc) { - if (storeOtherInLoop) - builder.create(loc, scalarOther, resMemRef, ivs); - builder.create(loc); - }); - } - - // Load vector from the temp storage and return it from alloca scope. - rewriter.setInsertionPointAfter(forOps.front()); - SmallVector indices(vecTy.getRank(), zeroIdx); - Value res = - rewriter.create(loc, vecTy, resMemRef, indices); - - rewriter.replaceOp(loadOp, res); - return success(); - } - - LogicalResult scalarizeWithLoop(triton::StoreOp storeOp, - ConversionPatternRewriter &rewriter) const { - auto loc = storeOp.getLoc(); - auto vecTy = dyn_cast( - getTypeConverter()->convertType(storeOp.getValue().getType())); - - auto ptrs = storeOp.getPtr(); - auto mask = storeOp.getMask(); - auto vals = storeOp.getValue(); - auto cache = storeOp.getCache(); - auto evict = storeOp.getEvict(); - - // Create some reused constants. - Value zeroIdx = rewriter.create(loc, 0); - Value oneIdx = rewriter.create(loc, 1); - - // Alloca is inserted similar to the load case. - Operation *allocaPoint = storeOp; - while (!isa(allocaPoint->getParentOp())) - allocaPoint = allocaPoint->getParentOp(); - - // Store a tensor of pointers, mask, and values into a temp buf if we can't - // compute them in a loop. - Value tmpPtrs = - maybeStoreVecToTempBuf(loc, ptrs, zeroIdx, allocaPoint, rewriter); - Value tmpMask = - maybeStoreVecToTempBuf(loc, mask, zeroIdx, allocaPoint, rewriter); - Value tmpVals = - maybeStoreVecToTempBuf(loc, vals, zeroIdx, allocaPoint, rewriter); - - // Create for-loops to iterate through all vector dimensions. - SmallVector forOps; - SmallVector ivs; - for (int64_t i = 0; i < vecTy.getRank(); ++i) { - Value upperBound = - rewriter.create(loc, vecTy.getShape()[i]); - auto forOp = - rewriter.create(loc, zeroIdx, upperBound, oneIdx); - forOps.push_back(forOp); - ivs.push_back(forOp.getInductionVar()); - rewriter.setInsertionPointToStart(forOp.getBody()); - } - - // Compute or load scalar args. 
- DenseMap valMap; - Value scalarPtr = - computeOrLoadScalarValue(ptrs, tmpPtrs, ivs, rewriter, valMap); - Value scalarMask = - computeOrLoadScalarValue(mask, tmpMask, ivs, rewriter, valMap); - Value scalarVal = - computeOrLoadScalarValue(vals, tmpVals, ivs, rewriter, valMap); - - if (!mask) { - // Regular store case. - rewriter.create(loc, scalarPtr, scalarVal, cache, evict); - } else { - // Conditional store case - rewriter.create(loc, scalarMask, - [&](OpBuilder &builder, Location loc) { - builder.create( - loc, scalarPtr, scalarVal, cache, evict); - builder.create(loc); - }); - } - - rewriter.eraseOp(storeOp); - return success(); - } - protected: ModuleAxisInfoAnalysis &axisAnalysis; ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; - bool genScalarLoops; }; struct LoadOpConversion : public MemoryOpConversion { @@ -579,8 +174,8 @@ struct LoadOpConversion : public MemoryOpConversion { if (!triton::isTensorPointerType(ptr.getType())) { auto axisInfo = axisAnalysis.getAxisInfo(ptr); - if (axisInfo) { - return lowerUsingAxisInfo(axisInfo, loadOp, rewriter); + if (isContiguousRowMajorAccess(axisInfo, loadOp)) { + return lowerToContiguousRowMajor(loadOp, rewriter); } return lowerToScalarLoads(loadOp, rewriter); } @@ -607,8 +202,9 @@ struct LoadOpConversion : public MemoryOpConversion { return success(); } - LogicalResult lowerUsingAxisInfo(AxisInfo *axisInfo, triton::LoadOp loadOp, - ConversionPatternRewriter &rewriter) const { + LogicalResult + lowerToContiguousRowMajor(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { // This is an experimental code that covers only a simple case of axis info // usage to demostrate load by tensor of pointers transformation into vector // loads. @@ -618,53 +214,47 @@ struct LoadOpConversion : public MemoryOpConversion { auto vecTy = dyn_cast(getTypeConverter()->convertType(loadOp.getType())); auto shape = vecTy.getShape(); - auto contiguity = axisInfo->getContiguity(); - if (shape.back() > 1 && shape.back() == contiguity.back()) { - auto strides = computeStrides(shape); - int64_t numElems = vecTy.getNumElements(); - Type subVecTy = VectorType::get(shape.back(), vecTy.getElementType()); - Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); - Value mask = loadOp.getMask() - ? rewriter.getRemappedValue(loadOp.getMask()) - : nullptr; - Value zeroIdx = rewriter.create(loc, 0); - Value defaultVal = convertOtherVal(loadOp, rewriter); - Value res = defaultVal; - for (int64_t idx = 0; idx < numElems; idx += shape.back()) { - auto indices = delinearize(idx, strides); - SmallVector subIndices(indices.begin(), - indices.begin() + indices.size() - 1); - auto ptr = - extractScalarPointer(loc, loadOp.getPtr(), indices, rewriter); - Value memRef = - rewriter.create(loc, memRefTy, ptr); - Value vec; - if (mask) { - Value subMask = mask; - Value passThru = defaultVal; - if (shape.size() > 1) { - subMask = rewriter.create(loc, mask, subIndices); - passThru = - rewriter.create(loc, defaultVal, subIndices); - } - vec = rewriter.create( - loc, subVecTy, memRef, zeroIdx, subMask, passThru); - } else { - vec = rewriter.create(loc, subVecTy, memRef, zeroIdx); - } + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type subVecTy = VectorType::get(shape.back(), vecTy.getElementType()); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = loadOp.getMask() ? 
rewriter.getRemappedValue(loadOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + Value defaultVal = convertOtherVal(loadOp, rewriter); + Value res = defaultVal; + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subIndices(indices.begin(), + indices.begin() + indices.size() - 1); + auto ptr = extractScalarPointer(loc, loadOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + Value vec; + if (mask) { + Value subMask = mask; + Value passThru = defaultVal; if (shape.size() > 1) { - res = rewriter.create(loc, vec, res, subIndices); - } else { - res = vec; + subMask = rewriter.create(loc, mask, subIndices); + passThru = + rewriter.create(loc, defaultVal, subIndices); } + vec = rewriter.create(loc, subVecTy, memRef, + zeroIdx, subMask, passThru); + } else { + vec = rewriter.create(loc, subVecTy, memRef, zeroIdx); } - rewriter.replaceOp(loadOp, res); - return success(); + if (shape.size() > 1) { + res = rewriter.create(loc, vec, res, subIndices); + } else { + res = vec; + } } - return lowerToScalarLoads(loadOp, rewriter); + rewriter.replaceOp(loadOp, res); + return success(); } LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, @@ -677,12 +267,6 @@ struct LoadOpConversion : public MemoryOpConversion { auto vecTy = dyn_cast(getTypeConverter()->convertType(loadOp.getType())); - // We want to avoid a code explosion when scalarize loads of big vectors, - // so try to build a scalar loop. - if (genScalarLoops && vecTy.getNumElements() >= 16 && - succeeded(scalarizeWithLoop(loadOp, rewriter))) - return success(); - auto ptrs = rewriter.getRemappedValue(loadOp.getPtr()); auto mask = loadOp.getMask() ? rewriter.getRemappedValue(loadOp.getMask()) : nullptr; @@ -742,8 +326,8 @@ struct StoreOpConversion : public MemoryOpConversion { if (!triton::isTensorPointerType(ptr.getType())) { auto axisInfo = axisAnalysis.getAxisInfo(ptr); - if (axisInfo) { - return lowerUsingAxisInfo(axisInfo, storeOp, rewriter); + if (isContiguousRowMajorAccess(axisInfo, storeOp)) { + return lowerToContiguousRowMajor(storeOp, rewriter); } return lowerToScalarStores(storeOp, rewriter); } @@ -767,8 +351,9 @@ struct StoreOpConversion : public MemoryOpConversion { return success(); } - LogicalResult lowerUsingAxisInfo(AxisInfo *axisInfo, triton::StoreOp storeOp, - ConversionPatternRewriter &rewriter) const { + LogicalResult + lowerToContiguousRowMajor(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { // This is an experimental code that covers only a simple case of axis info // usage to demostrate load by tensor of pointers transformation into vector // loads. @@ -778,44 +363,38 @@ struct StoreOpConversion : public MemoryOpConversion { auto vals = rewriter.getRemappedValue(storeOp.getValue()); auto vecTy = dyn_cast(vals.getType()); auto shape = vecTy.getShape(); - auto contiguity = axisInfo->getContiguity(); - if (shape.back() > 1 && shape.back() == contiguity.back()) { - auto strides = computeStrides(shape); - int64_t numElems = vecTy.getNumElements(); - Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); - Value mask = storeOp.getMask() - ? 
rewriter.getRemappedValue(storeOp.getMask()) - : nullptr; - Value zeroIdx = rewriter.create(loc, 0); - auto vals = rewriter.getRemappedValue(storeOp.getValue()); - for (int64_t idx = 0; idx < numElems; idx += shape.back()) { - auto indices = delinearize(idx, strides); - auto ptr = - extractScalarPointer(loc, storeOp.getPtr(), indices, rewriter); - Value memRef = - rewriter.create(loc, memRefTy, ptr); - indices.pop_back(); - auto val = rewriter.create(loc, vals, indices); - - if (mask) { - Value subMask = mask; - if (shape.size() > 1) { - SmallVector subIndices = indices; - subIndices.pop_back(); - subMask = rewriter.create(loc, mask, indices); - } - rewriter.create(loc, memRef, zeroIdx, subMask, - val); - } else { - rewriter.create(loc, val, memRef, zeroIdx); + + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = storeOp.getMask() + ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + Value zeroIdx = rewriter.create(loc, 0); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + auto ptr = extractScalarPointer(loc, storeOp.getPtr(), indices, rewriter); + Value memRef = + rewriter.create(loc, memRefTy, ptr); + indices.pop_back(); + auto val = rewriter.create(loc, vals, indices); + + if (mask) { + Value subMask = mask; + if (shape.size() > 1) { + SmallVector subIndices = indices; + subIndices.pop_back(); + subMask = rewriter.create(loc, mask, indices); } + rewriter.create(loc, memRef, zeroIdx, subMask, + val); + } else { + rewriter.create(loc, val, memRef, zeroIdx); } - - rewriter.eraseOp(storeOp); - return success(); } - return lowerToScalarStores(storeOp, rewriter); + rewriter.eraseOp(storeOp); + return success(); } LogicalResult lowerToScalarStores(triton::StoreOp storeOp, @@ -827,12 +406,6 @@ struct StoreOpConversion : public MemoryOpConversion { auto loc = storeOp.getLoc(); auto tensorTy = dyn_cast(storeOp.getPtr().getType()); - // We want to avoid a code explosion when scalarize stores of big vectors, - // so try to build a scalar loop. - if (genScalarLoops && tensorTy.getNumElements() >= 16 && - succeeded(scalarizeWithLoop(storeOp, rewriter))) - return success(); - auto ptrs = rewriter.getRemappedValue(storeOp.getPtr()); auto mask = storeOp.getMask() ? 
rewriter.getRemappedValue(storeOp.getMask()) : nullptr; @@ -871,6 +444,46 @@ struct StoreOpConversion : public MemoryOpConversion { } }; +struct CpuStoreOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + LogicalResult + matchAndRewrite(triton::cpu::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp.getLoc(); + auto value = rewriter.getRemappedValue(storeOp.getSrc()); + auto memRef = storeOp.getDst(); + auto rank = dyn_cast(memRef.getType()).getRank(); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(rank, zeroIdx); + auto vecWrite = + rewriter.create(loc, value, memRef, + indices); //, inBounds); + rewriter.replaceOp(storeOp, vecWrite); + return success(); + } +}; + +struct CpuLoadOpConversion : public MemoryOpConversion { + using MemoryOpConversion::MemoryOpConversion; + + LogicalResult + matchAndRewrite(triton::cpu::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp.getLoc(); + auto memRef = loadOp.getSrc(); + auto rank = dyn_cast(memRef.getType()).getRank(); + auto resTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(resTy.getRank(), zeroIdx); + auto vecRead = + rewriter.create(loc, resTy, memRef, indices); + rewriter.replaceOp(loadOp, vecRead); + return success(); + } +}; + class MemoryOpConversionTarget : public ConversionTarget { public: explicit MemoryOpConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { @@ -882,6 +495,8 @@ class MemoryOpConversionTarget : public ConversionTarget { addLegalDialect(); addLegalOp(); + addIllegalOp(); + // Allow only scalar loads and stores. 
addDynamicallyLegalOp([](triton::LoadOp loadOp) { return loadOp.getType().isIntOrIndexOrFloat(); @@ -896,10 +511,6 @@ struct ConvertMemoryOps : public triton::cpu::impl::ConvertMemoryOpsBase { ConvertMemoryOps() = default; - ConvertMemoryOps(bool useScalarLoops) { - this->useScalarLoops = useScalarLoops; - } - void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -909,10 +520,9 @@ struct ConvertMemoryOps MemoryOpConversionTarget convTarget(*context); TritonToTritonCPUTypeConverter pointerConverter; RewritePatternSet patterns(context); - patterns.add(axisInfoAnalysis, shapeInfoAnalysis, - pointerConverter, useScalarLoops, context); - patterns.add(axisInfoAnalysis, shapeInfoAnalysis, - pointerConverter, useScalarLoops, context); + patterns.add(axisInfoAnalysis, shapeInfoAnalysis, + pointerConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); @@ -929,11 +539,6 @@ std::unique_ptr> createConvertMemoryOps() { return std::make_unique(); } -std::unique_ptr> -createConvertMemoryOps(bool useScalarLoops) { - return std::make_unique(useScalarLoops); -} - } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeInterface.cpp b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeInterface.cpp new file mode 100644 index 000000000000..f194d3e195dd --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeInterface.cpp @@ -0,0 +1,277 @@ +#include "cpu/include/ScalarizePass/ScalarizeInterfaceImpl.h" + +#include "cpu/include/ScalarizePass/ScalarizeInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +using namespace mlir; + +#include "cpu/include/ScalarizePass/ScalarizeInterface.cpp.inc" + +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +Value mlir::triton::cpu::computeScalarValue(Operation *scalarizationOp, + Value vals, + ArrayRef indices, + PatternRewriter &rewriter) { + auto scalarized = cast(scalarizationOp); + return scalarized.computeScalarValue(vals, indices, rewriter); +} + +Value mlir::triton::cpu::computeScalarValue(Operation *scalarizationOp, + Value vals, ValueRange indices, + PatternRewriter &rewriter) { + auto scalarized = cast(scalarizationOp); + return scalarized.computeScalarValueForLoop(vals, indices, rewriter); +} + +bool mlir::triton::cpu::canComputeScalarValue(Value vals) { + auto def = vals.getDefiningOp(); + if (!def) + return false; + auto scalarized = dyn_cast(def); + if (!scalarized) + return false; + return scalarized.canComputeScalarValue(vals); +} + +namespace { + +namespace detail { + +template struct value_type_trait { + using type = typename T::value_type; +}; + +template <> struct value_type_trait { + using type = Value; +}; + +template +T createZeroIndex(mlir::Location loc, PatternRewriter &rewriter) { + llvm_unreachable("Default implementation should be overwritten."); +} + +template <> +int64_t createZeroIndex(mlir::Location loc, PatternRewriter &rewriter) { + return 0; +} + +template <> +Value createZeroIndex(mlir::Location loc, PatternRewriter &rewriter) { + return rewriter.create(loc, 0); +} + +} // namespace detail + +// Using ScalariztionFunctor class to partially specialize helper method +template struct ScalariztionFunctor { + template + static Value getScalarValue(OpTy operation, Value vals, T indices, + PatternRewriter &rewriter) { + auto def = vals.getDefiningOp(); + OperationState newState(def->getLoc(), 
def->getName()); + for (auto operand : def->getOperands()) { + newState.operands.push_back(computeScalarValue( + operand.getDefiningOp(), operand, indices, rewriter)); + } + assert(def->getResults().size() == 1 && + "[Unsupported] Opearation have multiple outputs."); + newState.types.push_back( + cast(def->getResultTypes()[0]).getElementType()); + newState.attributes = def->getAttrs(); + return rewriter.create(newState)->getResult(0); + } +}; + +/// External model implementation of ScalarizeInterface for TritonOps. An +/// external model implementation is used for now till the use of +/// `ScalarizeInterface` is on-par with the current ScalarizeUsingForOp. This +/// allows to register this Interface for all required ops depending on it's +/// type. +template +struct TritonOpScalarizeInterface + : public ScalarizeInterface::ExternalModel, + OpTy> { + bool canComputeScalarValue(Operation *op, Value vals) const { + for (auto operand : op->getOperands()) { + if (isa(operand)) { + return false; + } + auto scalarized = dyn_cast(operand.getDefiningOp()); + if (!scalarized) { + return false; + } + if (!scalarized.canComputeScalarValue(operand)) { + return false; + } + } + return true; + } + + Value computeScalarValue(Operation *op, Value vals, ArrayRef indices, + PatternRewriter &rewriter) const { + OpTy def = vals.getDefiningOp(); + return ScalariztionFunctor().getScalarValue(def, vals, indices, + rewriter); + } + + Value computeScalarValueForLoop(Operation *op, Value vals, ValueRange indices, + PatternRewriter &rewriter) const { + OpTy def = vals.getDefiningOp(); + return ScalariztionFunctor().getScalarValue(def, vals, indices, + rewriter); + } +}; +template <> struct ScalariztionFunctor { + template + Value getScalarValue(SplatOp def, Value vals, T indices, + PatternRewriter &rewriter) { + + return def.getSrc(); + } +}; + +template <> +bool TritonOpScalarizeInterface::canComputeScalarValue( + Operation *op, Value vals) const { + return true; +} + +template <> +struct TritonOpScalarizeInterface + : public ScalarizeInterface::ExternalModel< + TritonOpScalarizeInterface, MakeRangeOp> { + + bool canComputeScalarValue(Operation *op, Value vals) const { return true; } + + Value computeScalarValue(Operation *op, Value vals, ArrayRef indices, + PatternRewriter &rewriter) const { + MakeRangeOp def = vals.getDefiningOp(); + int32_t start = static_cast(def.getStart()); + assert(indices.size() == 1); + Type elemTy = cast(def.getType()).getElementType(); + return rewriter.create( + def.getLoc(), elemTy, + rewriter.getIntegerAttr(elemTy, start + indices[0])); + } + + Value computeScalarValueForLoop(Operation *op, Value vals, ValueRange indices, + PatternRewriter &rewriter) const { + MakeRangeOp def = vals.getDefiningOp(); + assert(indices.size() == 1); + int32_t start = static_cast(def.getStart()); + Type elemTy = cast(def.getType()).getElementType(); + Value startVal = rewriter.create( + def.getLoc(), elemTy, rewriter.getIntegerAttr(elemTy, start)); + Value index = indices[0]; + if (!elemTy.isIndex()) + index = + rewriter.create(def.getLoc(), elemTy, index); + return rewriter.create(def.getLoc(), elemTy, startVal, + index); + } +}; + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(BroadcastOp operation, Value vals, T indices, + PatternRewriter &rewriter) { + BroadcastOp def = operation; + using UnderlyingIndicesType = typename detail::value_type_trait::type; + // Find broadcasted dimensions and replace indices for those + // dimensions with 0 (broadcasted dimension has always 
size 1). + SmallVector newIndices; + auto sourceTy = cast(def.getSrc().getType()); + auto targetTy = cast(def.getType()); + assert(sourceTy.getRank() == indices.size() && "Mismatched rank"); + for (int64_t i = 0; i < sourceTy.getRank(); ++i) { + if (sourceTy.getShape()[i] != targetTy.getShape()[i]) + newIndices.push_back(detail::createZeroIndex( + std::move(def.getLoc()), rewriter)); + else + newIndices.push_back(indices[i]); + } + Value src = def.getSrc(); + return computeScalarValue(src.getDefiningOp(), src, newIndices, rewriter); + } +}; + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(ExpandDimsOp def, Value vals, T indices, + PatternRewriter &rewriter) { + using UnderlyingIndicesType = typename detail::value_type_trait::type; + // Remove index at expanded dimension. + SmallVector newIndices(indices); + newIndices.erase(newIndices.begin() + def.getAxis()); + Value src = def.getSrc(); + return computeScalarValue(src.getDefiningOp(), src, newIndices, rewriter); + } +}; + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(arith::ConstantOp def, Value vals, T indices, + PatternRewriter &rewriter) { + auto denseVal = cast(def.getValue()); + assert(denseVal.isSplat()); + auto scalarAttr = denseVal.getSplatValue(); + Value res = rewriter.create( + def.getLoc(), scalarAttr.getType(), scalarAttr); + return res; + } +}; + +template <> +bool TritonOpScalarizeInterface::canComputeScalarValue( + Operation *op, Value vals) const { + auto cst = static_cast(op); + if (auto denseVal = dyn_cast(cst.getValue())) { + return denseVal.isSplat(); + } + return false; +} + +template <> struct ScalariztionFunctor { + template + Value getScalarValue(TransOp def, Value vals, T indices, + PatternRewriter &rewriter) { + + using UnderlyingIndicesType = typename detail::value_type_trait::type; + + // Permute indices. 
+ SmallVector newIndices; + auto order = def.getOrder(); + assert(indices.size() == order.size() && "Mismatched rank"); + for (auto idx : order) + newIndices.push_back(indices[idx]); + Value src = def.getSrc(); + return computeScalarValue(src.getDefiningOp(), src, newIndices, rewriter); + } +}; + +} // namespace + +template static void registerOne(MLIRContext *ctx) { + OpType::template attachInterface>(*ctx); +} + +template static void registerAll(MLIRContext *ctx) { + (registerOne(ctx), ...); +} + +void mlir::triton::cpu::registerTritonOpScalarizeExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, TritonDialect *dialect) { + registerAll(ctx); + }); + registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) { + registerAll(ctx); + }); +} diff --git a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp new file mode 100644 index 000000000000..dbe7a31d3f86 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp @@ -0,0 +1,387 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include "cpu/include/ScalarizePass/ScalarizeInterface.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_SCALARIZEUSINGFOROP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +template +struct ScalarizeOpConversion : public OpRewritePattern { + + ScalarizeOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, + MLIRContext *context) + : OpRewritePattern(context), axisAnalysis(axisInfoAnalysis) {} + + Value createAlloca(Location loc, MemRefType ty, Operation *before, + PatternRewriter &rewriter) const { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(before); + return rewriter.create( + loc, ty, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); + } + + // If tensor is not null and its element cannot be recomputed in a scalar + // loop, then store it to a temporary buffer. + Value storeIfNonScalarizable(Location loc, Value vals, Value zeroIdx, + Operation *allocaPoint, + PatternRewriter &rewriter) const { + // To skip optional values and scalarizable value, that can be computed + // inside loop + if (!vals || canComputeScalarValue(vals)) + return nullptr; + + auto tensor = vals; + auto tensorTy = cast(vals.getType()); + auto elemTy = tensorTy.getElementType(); + if (isa(elemTy)) { + elemTy = IntegerType::get(elemTy.getContext(), 64); + } + // Memref of i1 assumes one element per byte when we load/store element, + // but vector store (through transfer write) would write 1 bit per element. 
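+  // The widened i8 tensor is stored here; loadOrComputeScalarValue
+  // truncates each loaded element back to i1 inside the loop body.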
+ if (elemTy.isInteger(1)) { + elemTy = rewriter.getI8Type(); + tensor = rewriter.create( + loc, + RankedTensorType::get(tensorTy.getShape(), elemTy, + tensorTy.getEncoding()), + tensor); + } + auto memRefTy = MemRefType::get(tensorTy.getShape(), elemTy); + Value memRef = createAlloca(vals.getLoc(), memRefTy, allocaPoint, rewriter); + SmallVector indices(tensorTy.getRank(), zeroIdx); + rewriter.create(vals.getLoc(), tensor, memRef); + return memRef; + } + + // Load scalar element from a temporary buffer or recompute it if the + // buffer doesn't exist. + Value loadOrComputeScalarValue(Value vals, Value tmpVals, ValueRange indices, + PatternRewriter &rewriter) const { + // Allow null value for easier handling of optional arguments. + if (!vals) + return nullptr; + + // If nothing loaded, value should be scalar computable + if (!tmpVals) { + if (!canComputeScalarValue(vals)) { + llvm::errs() + << "Passed value was not loaded and can't be computed as scalar: " + << vals << "\n"; + llvm::report_fatal_error("Cannot proceed such value"); + return nullptr; + } + return computeScalarValue(vals.getDefiningOp(), vals, indices, rewriter); + } + + // Load value from a temp buffer if any. + Value val = + rewriter.create(vals.getLoc(), tmpVals, indices); + // If we load a pointer then additional cast is needed because tensor of + // pointers is transformed into a vector of integers. + auto elemTy = dyn_cast(vals.getType()).getElementType(); + if (isa(elemTy)) + val = rewriter.create(vals.getLoc(), elemTy, val); + // We need to transform loaded i8 back to i1. + else if (elemTy.isInteger(1)) + val = rewriter.create(val.getLoc(), rewriter.getI1Type(), + val); + return val; + } + + // This is core methods that generates SCF::For + // We are checking arguments and results of operation + // to scalarize them if possible and load/store if they are dynamical + LogicalResult scalarizeWithLoop(OpTy scalarizeOp, + PatternRewriter &rewriter) const { + llvm_unreachable("nope"); + return failure(); + } + + // Method that describes how to check arguments and results of operation + // for scalarization + bool shouldScalarizeOp(OpTy scalarizeOp) const { + llvm_unreachable("nope"); + return false; + } + + // code for Memory Ops, as requires getPtr method + bool shouldScalarizeOpGeneric(OpTy scalarizeOp) const { + + auto ptr = scalarizeOp.getPtr(); + if (triton::isTensorPointerType(ptr.getType())) { + return false; + } + + auto axisInfo = axisAnalysis.getAxisInfo(ptr); + if (isContiguousRowMajorAccess(axisInfo, scalarizeOp)) { + return false; + } + + // Scalar memory ops and boundary checks are not expected. + if (!scalarizeOp.getBoundaryCheck().empty()) { + return false; + } + + return ScalarizeOpConversion::shouldScalarizeOp(scalarizeOp); + } + + LogicalResult matchAndRewrite(OpTy scalarOp, + PatternRewriter &rewriter) const override { + + // We want to avoid a code explosion when scalarize loads of big vectors, + // so try to build a scalar loop. 
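+    // The generated code has roughly this shape (sketch, not the exact IR):
+    //   scf.for %i = 0 to shape[0] { scf.for %j = 0 to shape[1] {
+    //     <compute or memref.load the scalar ptr/mask/value at (%i, %j)>
+    //     <scalar tt.load / tt.store, wrapped in scf.if when a mask is present>
+    //   } }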
+ if (shouldScalarizeOpGeneric(scalarOp) && + succeeded(scalarizeWithLoop(scalarOp, rewriter))) + return success(); + return failure(); + } + +protected: + ModuleAxisInfoAnalysis &axisAnalysis; +}; + +template <> +LogicalResult ScalarizeOpConversion::scalarizeWithLoop( + triton::StoreOp storeOp, PatternRewriter &rewriter) const { + auto loc = storeOp.getLoc(); + + auto ptrs = storeOp.getPtr(); + auto mask = storeOp.getMask(); + auto vals = storeOp.getValue(); + auto cache = storeOp.getCache(); + auto evict = storeOp.getEvict(); + + auto tensorTy = cast(vals.getType()); + + // Create some reused constants. + Value zeroIdx = rewriter.create(loc, 0); + Value oneIdx = rewriter.create(loc, 1); + + // Alloca is inserted similar to the load case. + Operation *allocaPoint = storeOp; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Store a tensor of pointers, mask, and values into a temp buf if we can't + // compute them in a loop. + Value tmpPtrs = + storeIfNonScalarizable(loc, ptrs, zeroIdx, allocaPoint, rewriter); + Value tmpMask = + storeIfNonScalarizable(loc, mask, zeroIdx, allocaPoint, rewriter); + Value tmpVals = + storeIfNonScalarizable(loc, vals, zeroIdx, allocaPoint, rewriter); + + // Create for-loops to iterate through all vector dimensions. + SmallVector forOps; + SmallVector ivs; + for (int64_t i = 0; i < tensorTy.getRank(); ++i) { + Value upperBound = + rewriter.create(loc, tensorTy.getShape()[i]); + auto forOp = rewriter.create(loc, zeroIdx, upperBound, oneIdx); + forOps.push_back(forOp); + ivs.push_back(forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Compute or load scalar args. + Value scalarPtr = loadOrComputeScalarValue(ptrs, tmpPtrs, ivs, rewriter); + Value scalarMask = loadOrComputeScalarValue(mask, tmpMask, ivs, rewriter); + Value scalarVal = loadOrComputeScalarValue(vals, tmpVals, ivs, rewriter); + + if (!mask) { + // Regular store case. + auto store_op = rewriter.create(loc, scalarPtr, scalarVal, + cache, evict); + } else { + // Conditional store case + rewriter.create(loc, scalarMask, + [&](OpBuilder &builder, Location loc) { + builder.create( + loc, scalarPtr, scalarVal, cache, evict); + builder.create(loc); + }); + } + + rewriter.eraseOp(storeOp); + return success(); +} + +template <> +bool ScalarizeOpConversion::shouldScalarizeOp( + triton::StoreOp scalarOp) const { + + if (!isa(scalarOp.getValue().getType())) { + return false; + } + + auto tensorTy = cast(scalarOp.getPtr().getType()); + return tensorTy.getNumElements() >= 16; +} + +template <> +LogicalResult ScalarizeOpConversion::scalarizeWithLoop( + triton::LoadOp loadOp, PatternRewriter &rewriter) const { + auto loc = loadOp.getLoc(); + auto tensorTy = cast(loadOp.getType()); + + auto ptrs = loadOp.getPtr(); + auto mask = loadOp.getMask(); + auto other = loadOp.getOther(); + auto cache = loadOp.getCache(); + auto evict = loadOp.getEvict(); + auto isVolatile = loadOp.getIsVolatile(); + + // Create some reused constants. + Value zeroIdx = rewriter.create(loc, 0); + Value oneIdx = rewriter.create(loc, 1); + + // There is alloca_scope operation to control alloca scopes. But its usage + // in combination with nested SCF and multi-dimensional vectors make it + // impossible to lower scopes to LLVM using existing MLIR passes. For now, + // simply allocate temp memory in the function's region. + // TODO: Use alloc for big buffers and revisit alloca scoping. 
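+  // The temporary buffers are therefore hoisted to the closest enclosing
+  // function, so all iterations of the generated loop nest share one
+  // allocation.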
+ Operation *allocaPoint = loadOp; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Allocate temp buffer for the result. Write the other value there if + // we cannot write it in a loop. + auto resMemRefTy = + MemRefType::get(tensorTy.getShape(), tensorTy.getElementType()); + Value resMemRef = createAlloca(loc, resMemRefTy, allocaPoint, rewriter); + bool storeOtherInLoop = static_cast(mask); + if (other && !canComputeScalarValue(other)) { + rewriter.create(loc, other, resMemRef); + storeOtherInLoop = false; + } + + // Store a tensor of pointers and mask into a temp buf if we can't + // compute them in a loop. + Value tmpPtrs = + storeIfNonScalarizable(loc, ptrs, zeroIdx, allocaPoint, rewriter); + Value tmpMask = + storeIfNonScalarizable(loc, mask, zeroIdx, allocaPoint, rewriter); + + // Create for-loops to iterate through all vector dimensions. + SmallVector forOps; + SmallVector ivs; + for (int64_t i = 0; i < tensorTy.getRank(); ++i) { + Value upperBound = + rewriter.create(loc, tensorTy.getShape()[i]); + auto forOp = rewriter.create(loc, zeroIdx, upperBound, oneIdx); + forOps.push_back(forOp); + ivs.push_back(forOp.getInductionVar()); + rewriter.setInsertionPointToStart(forOp.getBody()); + } + + // Compute or load a scalar arguments. + Value scalarPtr = loadOrComputeScalarValue(ptrs, tmpPtrs, ivs, rewriter); + Value scalarMask = loadOrComputeScalarValue(mask, tmpMask, ivs, rewriter); + Value scalarOther; + if (storeOtherInLoop) { + if (other) { + scalarOther = + computeScalarValue(other.getDefiningOp(), other, ivs, rewriter); + } else { + scalarOther = rewriter.create( + loc, tensorTy.getElementType(), + rewriter.getZeroAttr(tensorTy.getElementType())); + } + } + + if (!mask) { + // Regular load case. + Value val = rewriter.create(loc, scalarPtr, cache, evict, + isVolatile); + rewriter.create(loc, val, resMemRef, ivs); + } else { + // Conditional load case + rewriter.create( + loc, scalarMask, + [&](OpBuilder &builder, Location loc) { + Value val = builder.create(loc, scalarPtr, cache, + evict, isVolatile); + builder.create(loc, val, resMemRef, ivs); + builder.create(loc); + }, + [&](OpBuilder &builder, Location loc) { + if (storeOtherInLoop) + builder.create(loc, scalarOther, resMemRef, ivs); + builder.create(loc); + }); + } + + // Load vector from the temp storage and return it from alloca scope. 
+ rewriter.setInsertionPointAfter(forOps.front()); + SmallVector indices(tensorTy.getRank(), zeroIdx); + Value res = rewriter.create(loc, tensorTy, resMemRef); + rewriter.replaceOp(loadOp, res); + return success(); +} + +template <> +bool ScalarizeOpConversion::shouldScalarizeOp( + triton::LoadOp scalarOp) const { + if (!isa(scalarOp.getType())) { + return false; + } + auto tensorTy = cast(scalarOp.getType()); + return tensorTy.getNumElements() >= 16; +} + +struct ScalarizeUsingForOpPass + : public triton::impl::ScalarizeUsingForOpBase { + using ScalarizeUsingForOpBase::ScalarizeUsingForOpBase; + + ScalarizeUsingForOpPass() : ScalarizeUsingForOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + RewritePatternSet patterns(context); + patterns.add, + ScalarizeOpConversion>(axisInfoAnalysis, + context); + + if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) { + return signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createScalarizeUsingForOpPass() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 065d740bca79..97d17952f4fa 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -1,3 +1,4 @@ +#include "ScalarizePass/ScalarizeInterfaceImpl.h" #include "TritonCPUToLLVM/Passes.h" #include "TritonCPUTransforms/Passes.h" #include "TritonToTritonCPU/Passes.h" @@ -27,10 +28,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { .value("libsleef", cpu::VecLib::Sleef) .value("libmvec", cpu::VecLib::Mvec); - m.def("add_convert_memory_ops", - [](mlir::PassManager &pm, bool useScalarLoops) { - pm.addPass(mlir::triton::cpu::createConvertMemoryOps(useScalarLoops)); - }); + m.def("add_scalarize", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createScalarizeUsingForOpPass()); + }); + m.def("add_convert_memory_ops", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); + }); m.def("add_convert_ptr_ops", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertPtrOps()); }); @@ -147,6 +150,7 @@ void init_triton_cpu(py::module &&m) { mlir::DialectRegistry registry; registry.insert(); + mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); From 358d3af0b70525ecd738508f1523b19f9f4f829a Mon Sep 17 00:00:00 2001 From: Junyi Mei Date: Tue, 15 Oct 2024 00:11:27 +0800 Subject: [PATCH 122/165] Lower memory ops with vector gather and scatter (#158) * Lower memory ops with vector gather and scatter This commit add lowerToGather and lowerToScatter for load and store conversion. Memory ops with the pointer computed from splat and addptr can be lowered with vector.gather or vector.scatter. For architectures with scatter and gather support (like SVE and RVV), the code generated with this approach might be more efficient. Two options are added to scalarization and memory op conversion to enable lowering with gather and scatter operations. 
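As a rough illustration (hand-written sketch, not copied from the lit tests;
names and exact syntax are abbreviated), a masked load whose pointer is built
as

    %base = tt.splat %arg0 : !tt.ptr<i32> -> tensor<32x!tt.ptr<i32>>
    %ptrs = tt.addptr %base, %offsets : tensor<32x!tt.ptr<i32>>, tensor<32xi32>
    %vals = tt.load %ptrs, %mask, %other : tensor<32x!tt.ptr<i32>>

is now lowered to a single vector.gather that uses a memref derived from %arg0
as the base, %offsets as the index vector, and %mask/%other as the mask and
pass-through values, instead of a per-element scf.if / tt.load sequence (see
the updated lit test below).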
Signed-off-by: Junyi Mei * Fix incorrect rank and type in gather and scatter Signed-off-by: Junyi Mei * Lower store op with 1-D vector scatter ops Signed-off-by: Junyi Mei --------- Signed-off-by: Junyi Mei --- test/TritonCPU/convert-memory-ops.mlir | 27 ++--- third_party/cpu/backend/compiler.py | 4 +- .../cpu/include/TritonToTritonCPU/Passes.h | 36 ++++++ .../cpu/include/TritonToTritonCPU/Passes.td | 13 ++ .../TritonToTritonCPU/ConvertMemoryOps.cpp | 113 +++++++++++++++++- .../ScalarizeUsingForOps.cpp | 30 ++++- third_party/cpu/triton_cpu.cc | 10 +- 7 files changed, 199 insertions(+), 34 deletions(-) diff --git a/test/TritonCPU/convert-memory-ops.mlir b/test/TritonCPU/convert-memory-ops.mlir index c98747269fdc..710b76279610 100644 --- a/test/TritonCPU/convert-memory-ops.mlir +++ b/test/TritonCPU/convert-memory-ops.mlir @@ -1,18 +1,10 @@ -// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops | FileCheck %s +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-memory-ops=use-gather-scatter=true -cse | FileCheck %s -// Convert strided masked loads to scalar loads. +// Convert strided masked loads to gather. // CHECK-LABEL: @strided_masked_loads -// CHECK: %[[COND:.+]] = vector.extract %[[MASK:.+]][[[#IDX:]]] : i1 -// CHECK-NEXT: scf.if %[[COND]] -> (vector<32xi32>) { -// CHECK-NEXT: %[[PTR:.+]] = vector.extract %[[IN:.+]][[[#IDX]]] : i64 from vector<32xi64> -// CHECK-NEXT: %[[PTR_:.+]] = tt.int_to_ptr %[[PTR]] : i64 -> !tt.ptr -// CHECK-NEXT: %[[VAL:.+]] = tt.load %[[PTR_]] : !tt.ptr -// CHECK-NEXT: %[[NEW_OUT:.+]] = vector.insert %[[VAL]], %[[OUT:.+]] [[[#IDX]]] : i32 into vector<32xi32> -// CHECK-NEXT: scf.yield %[[NEW_OUT]] : vector<32xi32> -// CHECK-NEXT: } else { -// CHECK-NEXT: scf.yield %[[OUT]] : vector<32xi32> -// CHECK-NEXT: } +// CHECK: %[[PTR:.+]] = triton_cpu.ptr_to_memref %[[BASE:.+]] : -> memref +// CHECK: %[[VAL:.+]] = vector.gather %[[PTR]][] [%[[INDEX_VEC:.+]]], %[[MASK:.+]], %[[OTHER:.+]] : memref, vector<32xi32>, vector<32xi1>, vector<32xi32> into vector<32xi32> module { tt.func public @strided_masked_loads(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { @@ -36,16 +28,11 @@ module { // ----- -// Convert strided masked stores to scalar stores. +// Convert strided masked stores to scatter. 
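+// The stored values are written with vector.scatter, one 1-D slice at a time,
+// using the pointer offsets as the index vector and the original mask (or an
+// all-true constant mask) as the scatter mask.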
// CHECK-LABEL: @strided_masked_stores -// CHECK: %[[COND:.+]] = vector.extract %[[MASK:.+]][[[#IDX:]]] : i1 from vector<32xi1> -// CHECK-NEXT: scf.if %[[COND]] { -// CHECK-NEXT: %[[PTR:.+]] = vector.extract %[[OUT:.+]][[[#IDX]]] : i64 from vector<32xi64> -// CHECK-NEXT: %[[PTR_:.+]] = tt.int_to_ptr %[[PTR]] : i64 -> !tt.ptr -// CHECK-NEXT: %[[VAL:.+]] = vector.extract %[[IN:.+]][[[#IDX]]] : i32 from vector<32xi32> -// CHECK-NEXT: tt.store %[[PTR_]], %[[VAL]] : !tt.ptr -// CHECK-NEXT: } +// CHECK: %[[PTR:.+]] = triton_cpu.ptr_to_memref %[[BASE:.+]] : -> memref +// CHECK: vector.scatter %[[PTR]][] [%[[INDEX_VEC:.+]]], %[[MASK:.+]], %[[VALS:.+]] : memref, vector<32xi32>, vector<32xi1>, vector<32xi32> module { tt.func public @strided_masked_stores(%arg0: !tt.ptr {tt.divisibility = 16 : i32} ) { diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 991d694e854c..af366f937547 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -122,8 +122,8 @@ def make_ttcir(mod, metadata, opt): # TTIR -> TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() - cpu.passes.ttcpuir.add_scalarize(pm) - cpu.passes.ttcpuir.add_convert_memory_ops(pm) + cpu.passes.ttcpuir.add_scalarize(pm, True) + cpu.passes.ttcpuir.add_convert_memory_ops(pm, True) cpu.passes.ttcpuir.add_convert_ptr_ops(pm) cpu.passes.ttcpuir.add_convert_elementwise_ops(pm) cpu.passes.ttcpuir.add_convert_elem_manip_ops(pm) diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index a84b69c4b754..cd0babee3de4 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -2,6 +2,7 @@ #define TRITONTOTRITONCPU_CONVERSION_PASSES_H #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "triton/Analysis/AxisInfo.h" @@ -23,6 +24,8 @@ namespace cpu { std::unique_ptr> createConvertElementwiseOps(); std::unique_ptr> createConvertElemManipOps(); std::unique_ptr> createConvertMemoryOps(); +std::unique_ptr> +createConvertMemoryOps(bool useGatherScatter); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); @@ -35,6 +38,8 @@ std::unique_ptr> createConvertAtomicOps(); std::unique_ptr> createConvertDebugOps(); std::unique_ptr> createScalarizeUsingForOpPass(); +std::unique_ptr> +createScalarizeUsingForOpPass(bool skipGatherScatter); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonToTritonCPU/Passes.h.inc" @@ -84,6 +89,37 @@ bool isContiguousRowMajorAccess(AxisInfo *axisInfo, OpTy op) { return (shape.back() > 1 && shape.back() == contiguity.back()); } +// Get the base pointer and offset of a memory operation if the pointer is +// defined by a SplatOp and an AddPtrOp. 
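+// For example (illustrative IR, not taken from a test), given
+//   %base = tt.splat %arg0 : !tt.ptr<f32> -> tensor<32x!tt.ptr<f32>>
+//   %ptrs = tt.addptr %base, %offs : tensor<32x!tt.ptr<f32>>, tensor<32xi32>
+//   %v    = tt.load %ptrs : tensor<32x!tt.ptr<f32>>
+// the result is {%arg0, %offs}; when the pattern does not match, a pair of
+// null values is returned.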
+template , bool> = true> +std::tuple getMemoryBaseOffset(OpTy op) { + Value ptr = op.getPtr(); + + auto addPtrOp = ptr.getDefiningOp(); + if (!addPtrOp) + return std::make_tuple(nullptr, nullptr); + + Value basePtr = nullptr; + Value offset = nullptr; + + if (auto splatOp = addPtrOp->getOperand(0).getDefiningOp()) { + if (isa(splatOp.getOperand().getType())) { + basePtr = splatOp.getOperand(); + offset = addPtrOp.getOperand(1); + } + } + + if (auto splatOp = addPtrOp->getOperand(1).getDefiningOp()) { + if (!basePtr && isa(splatOp.getOperand().getType())) { + basePtr = splatOp.getOperand(); + offset = addPtrOp.getOperand(0); + } + } + + return std::make_tuple(basePtr, offset); +} + } // namespace cpu } // namespace triton diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 230731249783..8def195cc220 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -8,6 +8,13 @@ def ConvertMemoryOps : Pass<"triton-cpu-convert-memory-ops", "mlir::ModuleOp"> { let description = [{ }]; + + let options = [ + Option<"useGatherScatter", "use-gather-scatter", + "bool", /*default*/"false", + "Use Gather or Scatter to lower memory ops.">, + ]; + let constructor = "mlir::triton::cpu::createConvertMemoryOps()"; let dependentDialects = ["mlir::arith::ArithDialect", @@ -173,6 +180,12 @@ def ScalarizeUsingForOp : Pass<"triton-cpu-scalarize", "mlir::ModuleOp"> { the amount of IR without any further optimization. }]; + let options = [ + Option<"skipGatherScatter", "skip-gather-scatter", + "bool", /*default*/"false", + "Skip scalarizing gather/scatter ops.">, + ]; + let constructor = "mlir::triton::cpu::createScalarizeUsingForOpPass()"; let dependentDialects = ["mlir::arith::ArithDialect", diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp index 2ca6becadcfe..51729ca6618f 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertMemoryOps.cpp @@ -45,9 +45,12 @@ struct MemoryOpConversion : public OpConversionPattern { MemoryOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, - TypeConverter &typeConverter, MLIRContext *context) + TypeConverter &typeConverter, MLIRContext *context, + bool useGatherScatter) : OpConversionPattern(typeConverter, context), - axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis) {} + axisAnalysis(axisInfoAnalysis), shapeAnalysis(shapeInfoAnalysis) { + this->useGatherScatter = useGatherScatter; + } Value extractScalarPointer(Location loc, Value ptrs, ArrayRef indices, @@ -137,6 +140,7 @@ struct MemoryOpConversion : public OpConversionPattern { protected: ModuleAxisInfoAnalysis &axisAnalysis; ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis; + bool useGatherScatter; }; struct LoadOpConversion : public MemoryOpConversion { @@ -177,6 +181,9 @@ struct LoadOpConversion : public MemoryOpConversion { if (isContiguousRowMajorAccess(axisInfo, loadOp)) { return lowerToContiguousRowMajor(loadOp, rewriter); } + if (useGatherScatter && succeeded(lowerToGather(loadOp, rewriter))) { + return success(); + } return lowerToScalarLoads(loadOp, rewriter); } @@ -257,6 +264,44 @@ struct LoadOpConversion : public MemoryOpConversion { return success(); } + LogicalResult lowerToGather(triton::LoadOp loadOp, + ConversionPatternRewriter &rewriter) const { + auto loc = 
loadOp.getLoc(); + auto vecTy = dyn_cast( + getTypeConverter()->convertType(loadOp.getResult().getType())); + auto shape = vecTy.getShape(); + + auto [basePtr, offset] = getMemoryBaseOffset(loadOp); + + if (!basePtr || !offset) + return failure(); + + auto pointeeType = + dyn_cast(basePtr.getType()).getPointeeType(); + + auto gatherBase = rewriter.create( + loc, MemRefType::get({}, pointeeType), basePtr); + auto gatherIndices = SmallVector(); + auto gatherIndexVec = rewriter.getRemappedValue(offset); + + Value gatherMask; + if (auto loadMask = loadOp.getMask()) { + gatherMask = rewriter.getRemappedValue(loadMask); + } else { + auto maskType = VectorType::get(shape, rewriter.getI1Type()); + gatherMask = rewriter.create( + loc, maskType, DenseElementsAttr::get(maskType, true)); + } + + auto passThru = convertOtherVal(loadOp, rewriter); + + auto gatherOp = + rewriter.create(loc, vecTy, gatherBase, gatherIndices, + gatherIndexVec, gatherMask, passThru); + rewriter.replaceOp(loadOp, gatherOp); + return success(); + } + LogicalResult lowerToScalarLoads(triton::LoadOp loadOp, ConversionPatternRewriter &rewriter) const { // Scalar loads and boundary checks are not expected. @@ -329,6 +374,9 @@ struct StoreOpConversion : public MemoryOpConversion { if (isContiguousRowMajorAccess(axisInfo, storeOp)) { return lowerToContiguousRowMajor(storeOp, rewriter); } + if (useGatherScatter && succeeded(lowerToScatter(storeOp, rewriter))) { + return success(); + } return lowerToScalarStores(storeOp, rewriter); } @@ -397,6 +445,55 @@ struct StoreOpConversion : public MemoryOpConversion { return success(); } + LogicalResult lowerToScatter(triton::StoreOp storeOp, + ConversionPatternRewriter &rewriter) const { + auto loc = storeOp.getLoc(); + auto vals = rewriter.getRemappedValue(storeOp.getValue()); + auto vecTy = dyn_cast(vals.getType()); + auto shape = vecTy.getShape(); + + auto [basePtr, offset] = getMemoryBaseOffset(storeOp); + + if (!basePtr || !offset) + return failure(); + + auto strides = computeStrides(shape); + int64_t numElems = vecTy.getNumElements(); + Type memRefTy = MemRefType::get(shape.back(), vecTy.getElementType()); + Value mask = storeOp.getMask() + ? rewriter.getRemappedValue(storeOp.getMask()) + : nullptr; + + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + indices.pop_back(); + + auto val = rewriter.create(loc, vals, indices); + auto indexVec = rewriter.create( + loc, rewriter.getRemappedValue(offset), indices); + Value scatterMask; + + if (mask) { + scatterMask = rewriter.create(loc, mask, indices); + } else { + // Create a mask with all true values if no mask is provided. + auto maskType = VectorType::get({shape.back()}, rewriter.getI1Type()); + scatterMask = rewriter.create( + loc, maskType, DenseElementsAttr::get(maskType, true)); + } + + auto scatterBase = rewriter.create( + loc, MemRefType::get({}, vecTy.getElementType()), basePtr); + auto scatterIndices = SmallVector(); + + rewriter.create(loc, scatterBase, scatterIndices, + indexVec, scatterMask, val); + } + + rewriter.eraseOp(storeOp); + return success(); + } + LogicalResult lowerToScalarStores(triton::StoreOp storeOp, ConversionPatternRewriter &rewriter) const { // Scalar stores and boundary checks are not expected. 
@@ -511,6 +608,10 @@ struct ConvertMemoryOps : public triton::cpu::impl::ConvertMemoryOpsBase { ConvertMemoryOps() = default; + ConvertMemoryOps(bool useGatherScatter) { + this->useGatherScatter = useGatherScatter; + } + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -522,7 +623,8 @@ struct ConvertMemoryOps RewritePatternSet patterns(context); patterns.add(axisInfoAnalysis, shapeInfoAnalysis, - pointerConverter, context); + pointerConverter, context, + useGatherScatter); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure(); @@ -539,6 +641,11 @@ std::unique_ptr> createConvertMemoryOps() { return std::make_unique(); } +std::unique_ptr> +createConvertMemoryOps(bool useGatherScatter) { + return std::make_unique(useGatherScatter); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp index dbe7a31d3f86..0e8102831e1e 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp @@ -15,8 +15,10 @@ namespace mlir { namespace triton { +namespace cpu { #define GEN_PASS_DEF_SCALARIZEUSINGFOROP #include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace cpu } // namespace triton } // namespace mlir @@ -30,8 +32,10 @@ template struct ScalarizeOpConversion : public OpRewritePattern { ScalarizeOpConversion(ModuleAxisInfoAnalysis &axisInfoAnalysis, - MLIRContext *context) - : OpRewritePattern(context), axisAnalysis(axisInfoAnalysis) {} + MLIRContext *context, bool skipGatherScatter) + : OpRewritePattern(context), axisAnalysis(axisInfoAnalysis) { + this->skipGatherScatter = skipGatherScatter; + } Value createAlloca(Location loc, MemRefType ty, Operation *before, PatternRewriter &rewriter) const { @@ -138,6 +142,11 @@ struct ScalarizeOpConversion : public OpRewritePattern { return false; } + auto [basePtr, offset] = getMemoryBaseOffset(scalarizeOp); + if (skipGatherScatter && basePtr && offset) { + return false; + } + // Scalar memory ops and boundary checks are not expected. 
if (!scalarizeOp.getBoundaryCheck().empty()) { return false; @@ -159,6 +168,7 @@ struct ScalarizeOpConversion : public OpRewritePattern { protected: ModuleAxisInfoAnalysis &axisAnalysis; + bool skipGatherScatter; }; template <> @@ -351,11 +361,16 @@ bool ScalarizeOpConversion::shouldScalarizeOp( } struct ScalarizeUsingForOpPass - : public triton::impl::ScalarizeUsingForOpBase { + : public triton::cpu::impl::ScalarizeUsingForOpBase< + ScalarizeUsingForOpPass> { using ScalarizeUsingForOpBase::ScalarizeUsingForOpBase; ScalarizeUsingForOpPass() : ScalarizeUsingForOpBase() {} + ScalarizeUsingForOpPass(bool skipGatherScatter) : ScalarizeUsingForOpBase() { + this->skipGatherScatter = skipGatherScatter; + } + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -363,8 +378,8 @@ struct ScalarizeUsingForOpPass ModuleAxisInfoAnalysis axisInfoAnalysis(mod); RewritePatternSet patterns(context); patterns.add, - ScalarizeOpConversion>(axisInfoAnalysis, - context); + ScalarizeOpConversion>( + axisInfoAnalysis, context, skipGatherScatter); if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) { return signalPassFailure(); @@ -382,6 +397,11 @@ std::unique_ptr> createScalarizeUsingForOpPass() { return std::make_unique(); } +std::unique_ptr> +createScalarizeUsingForOpPass(bool skipGatherScatter) { + return std::make_unique(skipGatherScatter); +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 97d17952f4fa..9c4ca64e90b2 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -28,11 +28,13 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { .value("libsleef", cpu::VecLib::Sleef) .value("libmvec", cpu::VecLib::Mvec); - m.def("add_scalarize", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createScalarizeUsingForOpPass()); + m.def("add_scalarize", [](mlir::PassManager &pm, bool skip_gather_scatter) { + pm.addPass( + mlir::triton::cpu::createScalarizeUsingForOpPass(skip_gather_scatter)); }); - m.def("add_convert_memory_ops", [](mlir::PassManager &pm) { - pm.addPass(mlir::triton::cpu::createConvertMemoryOps()); + m.def("add_convert_memory_ops", [](mlir::PassManager &pm, + bool use_gather_scatter) { + pm.addPass(mlir::triton::cpu::createConvertMemoryOps(use_gather_scatter)); }); m.def("add_convert_ptr_ops", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertPtrOps()); From d80f30a516e30051d8e82a511a15c2004fdd0623 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 17 Oct 2024 11:19:16 -0500 Subject: [PATCH 123/165] Introduce DotOp lowering to AMX (#157) * Add DotOp lowering to AMX operations. Signed-off-by: Ilya Enkovich * Support direct tiles store to output memory. Signed-off-by: Ilya Enkovich * Add lit tests for amx. Signed-off-by: Ilya Enkovich * Fix review comments. 
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- bin/RegisterTritonDialects.h | 7 +- test/TritonCPU/dot-to-amx.mlir | 250 +++++ third_party/cpu/CMakeLists.txt | 2 +- third_party/cpu/backend/compiler.py | 15 + .../cpu/include/TritonCPUTransforms/Passes.h | 4 + .../cpu/include/TritonCPUTransforms/Passes.td | 27 + .../lib/TritonCPUTransforms/CMakeLists.txt | 1 + .../TritonCPUTransforms/ConvertDotToAMX.cpp | 910 ++++++++++++++++++ third_party/cpu/triton_cpu.cc | 27 +- 9 files changed, 1237 insertions(+), 6 deletions(-) create mode 100644 test/TritonCPU/dot-to-amx.mlir create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertDotToAMX.cpp diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index dee36aa7fa92..85f17c611b3e 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -27,6 +27,7 @@ #include "triton/Conversion/TritonToTritonGPU/Passes.h" #include "triton/Target/LLVMIR/Passes.h" +#include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/InitAllPasses.h" @@ -86,9 +87,9 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, mlir::memref::MemRefDialect, mlir::vector::VectorDialect, - mlir::tensor::TensorDialect, mlir::gpu::GPUDialect, - mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect, - mlir::triton::nvgpu::NVGPUDialect, + mlir::amx::AMXDialect, mlir::tensor::TensorDialect, + mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, + mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect, mlir::triton::amdgpu::TritonAMDGPUDialect, mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect>(); } diff --git a/test/TritonCPU/dot-to-amx.mlir b/test/TritonCPU/dot-to-amx.mlir new file mode 100644 index 000000000000..19c476403d2e --- /dev/null +++ b/test/TritonCPU/dot-to-amx.mlir @@ -0,0 +1,250 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-dot-to-amx="convert-bf16=true convert-fp16=true convert-i8=true" -canonicalize | FileCheck %s + +// Replacement of a contraction operation with a single tile_mulf operation. 
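+// Here M = N = 16 and K = 32 with bf16 inputs, so the LHS tile, the repacked
+// RHS tile, and the f32 accumulator tile each fit into a single AMX tile; one
+// tile_mulf is enough and its result is stored directly to the output memref.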
+ +// CHECK-LABEL: @test_single_mulf +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<16x32xbf16> +// CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> +// CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index +// CHECK: %[[ACC:.+]] = amx.tile_zero : vector<16x16xf32> +// CHECK-NEXT: %[[LHS:.+]] = amx.tile_load %3[%4#0, %4#1] +// CHECK-NEXT: %[[RHS:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] +// CHECK-NEXT: %[[RES:.+]] = amx.tile_mulf %[[LHS]], %[[RHS]], %[[ACC]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES]] : memref<16x16xf32, strided<[16, 1]>>, vector<16x16xf32> + +#loc = loc(unknown) +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + tt.func public @test_single_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf32> loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c32_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3 = triton_cpu.extract_memref %0 : > -> memref<16x32xbf16, strided<[32, 1]>> loc(#loc) + %4:2 = triton_cpu.extract_indices %0 : > -> index, index loc(#loc) + %5 = vector.transfer_read %3[%4#0, %4#1], %cst {in_bounds = [true, true]} : memref<16x32xbf16, strided<[32, 1]>>, vector<16x32xbf16> loc(#loc) + %6 = triton_cpu.extract_memref %1 : > -> memref<32x16xbf16, strided<[16, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<32x16xbf16, strided<[16, 1]>>, vector<32x16xbf16> loc(#loc) + %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %5, %8, %cst_0 : vector<16x32xbf16>, vector<32x16xbf16> into vector<16x16xf32> loc(#loc) + %10 = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> loc(#loc) + %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[16, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// Replacement of a contraction operation with multiple tile_muli operations. 
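+// Here K = 128 with i8 inputs, while a single i8 tile covers at most K = 64,
+// so the reduction is split into two tile_muli operations that accumulate into
+// the same 16x16xi32 tile before it is stored.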
+ +// CHECK-LABEL: @test_single_tile_two_muli +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xi8> +// CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x16xi32, strided<[16, 1]>> +// CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index +// CHECK: %[[ACC:.+]] = amx.tile_zero : vector<16x16xi32> +// CHECK-NEXT: %[[LHS1:.+]] = amx.tile_load %3[%4#0, %4#1] +// CHECK-NEXT: %[[RHS1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] +// CHECK-NEXT: %[[RES1:.+]] = amx.tile_muli %[[LHS1]], %[[RHS1]], %[[ACC]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> +// CHECK-NEXT: %[[IDX1:.+]] = arith.addi %4#1, %c64{{.*}} : index +// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : memref<16x128xi8, strided<[128, 1]>> into vector<16x64xi8> +// CHECK-NEXT: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xi8> into vector<16x64xi8> +// CHECK-NEXT: %[[RES2:.+]] = amx.tile_muli %[[LHS2]], %[[RHS2]], %[[RES1]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES2]] : memref<16x16xi32, strided<[16, 1]>>, vector<16x16xi32> + +#loc = loc(unknown) +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + tt.func public @test_single_tile_two_muli(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %c0_i8 = arith.constant 0 : i8 loc(#loc) + %cst = arith.constant dense<0> : vector<16x16xi32> loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c128_i64 = arith.constant 128 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c128_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3 = triton_cpu.extract_memref %0 : > -> memref<16x128xi8, strided<[128, 1]>> loc(#loc) + %4:2 = triton_cpu.extract_indices %0 : > -> index, index loc(#loc) + %5 = vector.transfer_read %3[%4#0, %4#1], %c0_i8 {in_bounds = [true, true]} : memref<16x128xi8, strided<[128, 1]>>, vector<16x128xi8> loc(#loc) + %6 = triton_cpu.extract_memref %1 : > -> memref<128x16xi8, strided<[16, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %c0_i8 {in_bounds = [true, true]} : memref<128x16xi8, strided<[16, 1]>>, vector<128x16xi8> loc(#loc) + %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %5, %8, %cst : vector<16x128xi8>, vector<128x16xi8> into vector<16x16xi32> loc(#loc) + %10 = triton_cpu.extract_memref %2 : > -> memref<16x16xi32, strided<[16, 1]>> loc(#loc) + %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32, strided<[16, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// 
Replacement of a contraction operation with multiple tile_mulf operations +// and multiple output tiles. + +// CHECK-LABEL: @test_two_tiles_four_mulf +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xbf16> +// CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x32xf32, strided<[32, 1]>> +// CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index +// CHECK: %[[ACC1:.+]] = amx.tile_zero : vector<16x16xf32> +// CHECK-NEXT: %[[ACC2:.+]] = amx.tile_zero : vector<16x16xf32> +// CHECK-NEXT: %[[LHS1:.+]] = amx.tile_load %3[%4#0, %4#1] : memref<16x64xbf16, strided<[64, 1]>> into vector<16x32xbf16> +// CHECK-NEXT: %[[RHS1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES1:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS1]], %[[ACC1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES2:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS2]], %[[ACC2]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK: %[[IDX1:.+]] = arith.addi %4#1, %c32{{.*}} : index +// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : memref<16x64xbf16, strided<[64, 1]>> into vector<16x32xbf16> +// CHECK: %[[RHS3:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES3:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS3]], %[[RES1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES3]] : memref<16x32xf32, strided<[32, 1]>>, vector<16x16xf32> +// CHECK: %[[RHS4:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES4:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS4]], %[[RES2]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK: %[[IDX2:.+]] = arith.addi %[[OUT_INDICES]]#1, %c16{{.*}} : index +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[IDX2]]], %[[RES4]] : memref<16x32xf32, strided<[32, 1]>>, vector<16x16xf32> + +#loc = loc(unknown) +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + tt.func public @test_two_tiles_four_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x32xf32> loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c64_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3 = triton_cpu.extract_memref %0 : > -> memref<16x64xbf16, strided<[64, 1]>> loc(#loc) + %4:2 = 
triton_cpu.extract_indices %0 : > -> index, index loc(#loc) + %5 = vector.transfer_read %3[%4#0, %4#1], %cst {in_bounds = [true, true]} : memref<16x64xbf16, strided<[64, 1]>>, vector<16x64xbf16> loc(#loc) + %6 = triton_cpu.extract_memref %1 : > -> memref<64x32xbf16, strided<[32, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x32xbf16, strided<[32, 1]>>, vector<64x32xbf16> loc(#loc) + %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %5, %8, %cst_0 : vector<16x64xbf16>, vector<64x32xbf16> into vector<16x32xf32> loc(#loc) + %10 = triton_cpu.extract_memref %2 : > -> memref<16x32xf32, strided<[32, 1]>> loc(#loc) + %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x32xf32>, memref<16x32xf32, strided<[32, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// More complicated case with a loop, input casts, and accumulator that +// cannot fit tile register file. + +// CHECK-LABEL: @test_loop_acc_two_blocks +// CHECK: %[[LHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<64x64xbf16> +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xbf16> +// CHECK: %[[ACC_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<64x32xf32> +// CHECK: vector.transfer_write %cst{{.+}}, %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}] {in_bounds = [true, true]} : vector<64x32xf32>, memref<64x32xf32> +// CHECK: %3:2 = scf.for %arg3 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg4 = %0, %arg5 = %1) -> (!tt.ptr>, !tt.ptr>) : i32 +// CHECK: %[[LHS:.+]] = vector.transfer_read %{{.+}}[%{{.+}}#0, %{{.+}}#1], %{{.+}} {in_bounds = [true, true]} : memref<64x128xf8E5M2, strided<[128, 1]>>, vector<64x64xf8E5M2> +// CHECK: %[[RHS:.+]] = vector.transfer_read %{{.+}}[%{{.+}}#0, %{{.+}}#1], %{{.+}} {in_bounds = [true, true]} : memref<128x32xf8E5M2, strided<[32, 1]>>, vector<64x32xf8E5M2> +// CHECK-NEXT: %[[LHS1:.+]] = arith.extf %[[LHS]] : vector<64x64xf8E5M2> to vector<64x64xbf16> +// CHECK-NEXT: vector.transfer_write %[[LHS1]], %[[LHS_BUF]][%c0{{.*}}, %c0{{.*}}] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16> +// CHECK-NEXT: %[[RHS1:.+]] = arith.extf %[[RHS]] : vector<64x32xf8E5M2> to vector<64x32xbf16> +// CHECK-COUNT-32: vector.store %{{.+}}, %[[RHS_BUF]][%{{.+}}, %{{.+}}] : memref<32x64xbf16>, vector<64xbf16> +// CHECK-NEXT: %[[ACC_0_0:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[ACC_0_1:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[ACC_1_0:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[ACC_1_1:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[LHS_0_0:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[LHS_1_0:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[TMP_0_0:.+]] = amx.tile_mulf 
%[[LHS_0_0]], %[[RHS_0_0]], %[[ACC_0_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[TMP_1_0:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_0]], %[[ACC_1_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[TMP_0_1:.+]] = amx.tile_mulf %[[LHS_0_0]], %[[RHS_0_1]], %[[ACC_0_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[TMP_1_1:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_1]], %[[ACC_1_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[LHS_0_1:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[LHS_1_1:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES_0_0:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_0]], %[[TMP_0_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}], %[[RES_0_0]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[RES_1_0:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_0]], %[[TMP_1_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}], %[[RES_1_0]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES_0_1:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_1]], %[[TMP_0_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}], %[[RES_0_1]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[RES_1_1:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_1]], %[[TMP_1_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}], %[[RES_1_1]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[ACC_2_0:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[ACC_2_1:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[ACC_3_0:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[ACC_3_1:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> +// CHECK-NEXT: %[[LHS_2_0:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[LHS_3_0:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[TMP_2_0:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_0]], %[[ACC_2_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[TMP_3_0:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_0]], %[[ACC_3_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : 
memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[TMP_2_1:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_1]], %[[ACC_2_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[TMP_3_1:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_1]], %[[ACC_3_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[LHS_2_1:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[LHS_3_1:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES_2_0:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_0]], %[[TMP_2_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}], %[[RES_2_0]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[RES_3_0:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_0]], %[[TMP_3_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}], %[[RES_3_0]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> +// CHECK-NEXT: %[[RES_2_1:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_1]], %[[TMP_2_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}], %[[RES_2_1]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[RES_3_1:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_1]], %[[TMP_3_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}], %[[RES_3_1]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK: %[[RES:.+]] = vector.transfer_read %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<64x32xf32>, vector<64x32xf32> + +#loc = loc(unknown) +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +module { + tt.func public @test_loop_acc_two_blocks(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : f8E5M2 loc(#loc) + %c2_i32 = arith.constant 2 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %c64_i32 = arith.constant 64 : i32 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<64x32xf32> loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c128_i64 = arith.constant 128 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c128_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c64_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3:3 = scf.for %arg3 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg4 = %cst_0, %arg5 = %0, %arg6 = %1) -> (vector<64x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + 
%6 = triton_cpu.extract_memref %arg5 : > -> memref<64x128xf8E5M2, strided<[128, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %arg5 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x128xf8E5M2, strided<[128, 1]>>, vector<64x64xf8E5M2> loc(#loc) + %9 = triton_cpu.extract_memref %arg6 : > -> memref<128x32xf8E5M2, strided<[32, 1]>> loc(#loc) + %10:2 = triton_cpu.extract_indices %arg6 : > -> index, index loc(#loc) + %11 = vector.transfer_read %9[%10#0, %10#1], %cst {in_bounds = [true, true]} : memref<128x32xf8E5M2, strided<[32, 1]>>, vector<64x32xf8E5M2> loc(#loc) + %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %8, %11, %arg4 : vector<64x64xf8E5M2>, vector<64x32xf8E5M2> into vector<64x32xf32> loc(#loc) + %13 = tt.advance %arg5, [%c0_i32, %c64_i32] : > loc(#loc) + %14 = tt.advance %arg6, [%c64_i32, %c0_i32] : > loc(#loc) + scf.yield %12, %13, %14 : vector<64x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %4 = triton_cpu.extract_memref %2 : > -> memref<64x32xf32, strided<[32, 1]>> loc(#loc) + %5:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %3#0, %4[%5#0, %5#1] {in_bounds = [true, true]} : vector<64x32xf32>, memref<64x32xf32, strided<[32, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 2acf7a6b6f48..59d0f5c53d46 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -4,7 +4,7 @@ add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) - target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation) endif() add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index af366f937547..2e346a259977 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -78,6 +78,14 @@ def __init__(self, target: tuple) -> None: self.cpu_arch = llvm.get_cpu_tripple().split("-")[0] self.cpu_name = llvm.get_cpu_name() self.cpu_features = llvm.get_cpu_features() + if 'amx-tile' in self.cpu_features: + if not cpu.enable_amx(): + import warnings + warnings.warn("Warning! Couldn't enable AMX for the process. 
AMX optimizations are disabled.") + self.cpu_features.discard('amx-tile') + self.cpu_features.discard('amx-int8') + self.cpu_features.discard('amx-fp16') + self.cpu_features.discard('amx-bf16') def parse_options(self, opts) -> Any: args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts} @@ -151,6 +159,13 @@ def make_tttcir(self, mod, metadata, opt): if convert_bf16_dot_product: use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1" cpu.passes.ttcpuir.add_convert_dot_product(pm, use_horizontal_sum) + if 'amx-tile' in self.cpu_features: + amx_int8 = 'amx-int8' in self.cpu_features + # amx_fp16 = 'amx-fp16' in self.cpu_features + # FP16 support is not in AMX dialect yet + amx_fp16 = False + amx_bf16 = 'amx-bf16' in self.cpu_features + cpu.passes.ttcpuir.add_convert_dot_to_amx(pm, amx_int8, amx_fp16, amx_bf16) promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features # We don't have any lowering for mixed precision matmuls, so always use casts for now convert_mixed_precision_matmul = True diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index ec4c10498891..d9c121cb219b 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -33,6 +33,10 @@ std::unique_ptr> createConvertDotProduct(); std::unique_ptr> createConvertDotProduct(bool useHorizontalSum); +std::unique_ptr> createConvertDotToAMX(); +std::unique_ptr> +createConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16); + #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index 656eff4f4fe7..c337d2c92eec 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -98,4 +98,31 @@ def ConvertDotProduct : Pass<"triton-cpu-convert-dot-product", "mlir::ModuleOp"> "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertDotToAMX : Pass<"triton-cpu-convert-dot-to-amx", "mlir::ModuleOp"> { + let summary = "Convert dot product op to AMX dialect."; + let description = [{ + This pass is used to lower matmul operations to amx dialect. 
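+    Matmuls are matched as vector.contract operations and rewritten into
+    amx.tile_load / amx.tile_mulf (or amx.tile_muli) / amx.tile_store
+    sequences; operands that do not fit the AMX tile register file are split
+    into blocks of tiles.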
+  }];
+
+  let options = [
+    Option<"convertInt8", "convert-i8",
+           "bool", /*default*/"false",
+           "Use AMX extensions for int8 type.">,
+    Option<"convertFp16", "convert-fp16",
+           "bool", /*default*/"false",
+           "Use AMX extensions for fp16 type.">,
+    Option<"convertBf16", "convert-bf16",
+           "bool", /*default*/"false",
+           "Use AMX extensions for bf16 type.">,
+  ];
+
+  let constructor = "mlir::triton::cpu::createConvertDotToAMX()";
+
+  let dependentDialects = ["mlir::arith::ArithDialect",
+                           "mlir::vector::VectorDialect",
+                           "mlir::amx::AMXDialect",
+                           "mlir::triton::TritonDialect",
+                           "mlir::triton::cpu::TritonCPUDialect"];
+}
+
 #endif
diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt
index 86277b3f0490..3bf2e3568238 100644
--- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt
+++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_triton_library(TritonCPUTransforms
   ConvertDotProduct.cpp
+  ConvertDotToAMX.cpp
   ConvertUnsupportedOps.cpp
   DecomposeFpConversions.cpp
   OptimizeMasks.cpp
diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotToAMX.cpp
new file mode 100644
index 000000000000..aacf150de3b1
--- /dev/null
+++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotToAMX.cpp
@@ -0,0 +1,910 @@
+#include "cpu/include/TritonCPUTransforms/OptCommon.h"
+
+#include "cpu/include/TritonCPUTransforms/Passes.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "include/triton/Analysis/Utility.h"
+#include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonCPU/IR/Dialect.h"
+#include
+#include
+
+namespace mlir {
+namespace triton {
+namespace cpu {
+#define GEN_PASS_DEF_CONVERTDOTTOAMX
+#include "cpu/include/TritonCPUTransforms/Passes.h.inc"
+} // namespace cpu
+} // namespace triton
+} // namespace mlir
+
+#define DEBUG_TYPE "triton-cpu-dot-to-amx"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+using namespace mlir;
+using namespace mlir::triton;
+using namespace mlir::triton::cpu;
+
+namespace {
+
+// This struct describes buffers used to load/store AMX tiles.
+struct AmxBuffer {
+  Value memRef;
+  SmallVector indices;
+
+  bool empty() const { return !memRef; }
+};
+
+// This structure is used to hold candidates for conversion to AMX
+// Mul[F|I]Op operations.
+struct AmxDotOpCandidate {
+  // Operation to convert.
+  vector::ContractionOp op;
+  // Available LHS, RHS, and accumulator types are limited in AMX and we might
+  // require additional casts. Here we keep actual element types used by LHS,
+  // RHS, and accumulator in AMX tiles.
+  Type lhsTileElemTy;
+  Type rhsTileElemTy;
+  Type accTileElemTy;
+  // AMX tile row size is limited by 64 bytes, so M and N dimensions are limited
+  // by 16 because accumulator always has 4-byte elements. K dimension for tiles
+  // is limited by 64 / <input element size in bytes>. Here we keep actual tile sizes.
+  int64_t tileM;
+  int64_t tileN;
+  int64_t tileK;
+  // We have a limited number of available tiles, so if input/output is too
+  // big to fit available tiles, we need to split them into blocks. Here we
+  // keep the number of tiles in an accumulator block. K dimension for input blocks
+  // is always 1 tile now.
+ int64_t tilesInBlockM; + int64_t tilesInBlockN; + // If accumulator is updated in a loop, then this flag indicates if we + // should keep it in tiles the whole loop and move back to vectors only + // after the loop. + bool keepAccOnTiles = false; + // If we want to keep accumulator in tiles but it's too big, then we might + // keep it bufferized instead. + bool keepAccInBuf = false; + // If resulting tiles are not required to be trasfered to vectors and can be + // directly stored to the output memory instead, then this field holds a + // buffer to use. + AmxBuffer outBuf; + // If output buffer is used then keep the original vector store here. + Operation *origStore = nullptr; +}; + +bool checkIdxMap(Attribute attr, unsigned int v1, unsigned int v2) { + auto map = cast(attr).getAffineMap(); + return map == + AffineMap::getMultiDimMapWithTargets(3, {v1, v2}, attr.getContext()); +} + +// Return true if specified contraction op is actually a converted DotOp. +bool isDotOp(vector::ContractionOp op) { + // First, check ranks of inputs. + if (cast(op.getLhs().getType()).getRank() != 2 || + cast(op.getRhs().getType()).getRank() != 2 || + cast(op.getAcc().getType()).getRank() != 2) { + LDBG("Drop candidate with rank != 2"); + return false; + } + + // Matmul uses add as a combining function. + if (op.getKind() != vector::CombiningKind::ADD) { + LDBG("Drop candidate with combining function " << op.getKind()); + return false; + } + + // Expect two parallel and one reduction iterators. + auto iterTypes = op.getIteratorTypes(); + if (iterTypes.size() != 3 || + cast(iterTypes[0]).getValue() != + vector::IteratorType::parallel || + cast(iterTypes[1]).getValue() != + vector::IteratorType::parallel || + cast(iterTypes[2]).getValue() != + vector::IteratorType::reduction) { + LDBG("Drop candidate with mismatched iterator types."); + return false; + } + + // Check affine maps. + // TODO: be less restrictive on maps to allow transposed inputs? + auto idxMaps = op.getIndexingMaps(); + if (!checkIdxMap(idxMaps[0], 0, 2) || !checkIdxMap(idxMaps[1], 2, 1) || + !checkIdxMap(idxMaps[2], 0, 1)) { + LDBG("Drop candidate with mismatched affine maps."); + return false; + } + + return true; +} + +// Check if input and output types can be handled by AMX (possibly, using +// additional casts for input/output). Returns true if AMX usage is possible. +// In this case, tile element type fields of the candidate structure are +// filled with actual types to be used in lowering. +bool checkElemTypes(Type lhsElemTy, Type rhsElemTy, Type accElemTy, + Type resElemTy, bool supportInt8, bool supportFp16, + bool supportBf16, AmxDotOpCandidate &candidate) { + MLIRContext *ctx = lhsElemTy.getContext(); + if (lhsElemTy.isInteger()) { + if (!supportInt8) { + LDBG("Drop candidate because AMX_INT8 is not available."); + return false; + } + + // For integer case only i8 is allowed for LHS and RHS. + if (!lhsElemTy.isInteger(8) || !rhsElemTy.isInteger(8)) { + LDBG("Drop candidate with unsupported input integer type."); + return false; + } + + // Accumulator should be i32. If it's smaller, we will use casts. + if (!accElemTy.isInteger() || accElemTy.getIntOrFloatBitWidth() > 32 || + !resElemTy.isInteger() || resElemTy.getIntOrFloatBitWidth() > 32) { + LDBG("Drop candidate with unsupported output integer type."); + return false; + } + + candidate.lhsTileElemTy = IntegerType::get(ctx, 8); + candidate.rhsTileElemTy = IntegerType::get(ctx, 8); + candidate.accTileElemTy = IntegerType::get(ctx, 32); + + return true; + } + + // FP case. 
Expect no integer args or result. + if (rhsElemTy.isInteger() || accElemTy.isInteger() || resElemTy.isInteger()) { + LDBG("Drop candidate with mixed int/fp types."); + return false; + } + + // For fp case LHS and RHS types should match and can be either FP16 or + // BF16. + if (lhsElemTy.getIntOrFloatBitWidth() > 16 || + rhsElemTy.getIntOrFloatBitWidth() > 16) { + LDBG("Drop candidate with unsupported input fp type."); + return false; + } + + // Try to find a common input type. There is currently no support + // for FP8 types, so promote them to FP16/BF16. + Type commonInputElemTy; + if (lhsElemTy.getIntOrFloatBitWidth() == 16) { + commonInputElemTy = lhsElemTy; + if (rhsElemTy.getIntOrFloatBitWidth() == 16 && + rhsElemTy != commonInputElemTy) { + LDBG("Drop candidate with mismatched input types."); + return false; + } + } else if (rhsElemTy.getIntOrFloatBitWidth() == 16) + commonInputElemTy = rhsElemTy; + // Both inputs are FP8, choose 16-bit FP type to use. + else if (supportBf16) + commonInputElemTy = BFloat16Type::get(ctx); + else + commonInputElemTy = Float16Type::get(ctx); + + if (commonInputElemTy.isF16() && !supportFp16) { + LDBG("Drop candidate because AMX_FP16 is not available."); + return false; + } + + if (commonInputElemTy.isBF16() && !supportBf16) { + LDBG("Drop candidate because AMX_BF16 is not available."); + return false; + } + + // Accumulator type should be FP32, we can use casts if it is smaller. + if (accElemTy.getIntOrFloatBitWidth() > 32) { + LDBG("Drop candidate with unsupported accumulator type."); + return false; + } + + candidate.lhsTileElemTy = commonInputElemTy; + candidate.rhsTileElemTy = commonInputElemTy; + candidate.accTileElemTy = Float32Type::get(ctx); + + return true; +} + +// Check if accumulator value is updated in a loop and has no other +// usages than a dot op, that updates it. Tile loads/stores and casts +// for such accumulators can be done outside of the loop. +bool isLoopCarriedAcc(Value acc) { + LDBG("Check if accumulator can be held in tiles: " << acc); + if (!acc.hasOneUse()) { + LDBG(" No. Has multiple uses."); + for (auto op : acc.getUsers()) + LDBG(" " << *op); + return false; + } + + auto blockArg = dyn_cast(acc); + if (!blockArg) { + LDBG(" No. Not a block argument."); + return false; + } + + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + if (!forOp) { + LDBG(" No. Not in a for-loop."); + return false; + } + + blockArg.getArgNumber(); + + Value updAcc = acc.getUsers().begin()->getResult(0); + if (!updAcc.hasOneUse()) { + LDBG(" No. Has multiple uses."); + return false; + } + + auto &updAccUse = *updAcc.getUses().begin(); + if (!isa(updAccUse.getOwner()) || + updAccUse.getOperandNumber() != + (blockArg.getArgNumber() - forOp.getNumInductionVars())) { + LDBG(" No. Loop carried dependency not detected."); + return false; + } + + LDBG(" Yes."); + return true; +} + +// Return a value that holds the resulting loop carried accumulator value. +// It's one of ForOp's results. +Value getResValueForLoopCarriedAcc(vector::ContractionOp op) { + Value updAcc = op.getResult(); + auto forOp = dyn_cast(op->getParentOp()); + auto &use = *updAcc.getUses().begin(); + return forOp.getResult(use.getOperandNumber()); +} + +// Choose tile and block sizes for the candidate. Tile sizes are determined +// by input shapes and types. Block sizes are chosen to minimize number of +// tile loads/stores including tile register spills. 
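+// For example, a BF16 dot with a 64x32 FP32 accumulator and K = 128 ends up
+// with 16x16 accumulator tiles, tileK = 32, and a 2x2 block of accumulator
+// tiles.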
+void setupBlockAndTileSizes(ArrayRef lhsShape, + ArrayRef rhsShape, + AmxDotOpCandidate &candidate) { + int64_t m = lhsShape[0]; + int64_t n = rhsShape[1]; + int64_t k = rhsShape[0]; + int64_t tileM = std::min(m, (int64_t)16); + int64_t tileN = std::min(n, (int64_t)16); + int64_t tileK = std::min( + k, (int64_t)512 / candidate.lhsTileElemTy.getIntOrFloatBitWidth()); + + int64_t accBlocksM = m / tileM; + int64_t accBlocksN = n / tileN; + + // All these sizes are power of 2. We have 8 tile registers and + // cannot use them all for accumulator. So, we will use up to 4 + // tiles for accumulator in a single block. + while (accBlocksM * accBlocksN > 4) { + if (accBlocksM > accBlocksN) + accBlocksM /= 2; + else + accBlocksN /= 2; + } + + candidate.tileM = tileM; + candidate.tileN = tileN; + candidate.tileK = tileK; + candidate.tilesInBlockM = accBlocksM; + candidate.tilesInBlockN = accBlocksN; +} + +// Check if vector transfer read/write operation uses a mask +// or involves a bounds check. +template bool hasMaskOrBoundsCheck(T op) { + auto inBounds = op.getInBounds(); + Value mask = op.getMask(); + bool hasBoundsCheck = + std::any_of(inBounds.begin(), inBounds.end(), [](Attribute attr) { + return !cast(attr).getValue(); + }); + return hasBoundsCheck || mask; +} + +// Check if a value is used only for a store and that this store can be +// replaced with tile stores. In this case fill appropriate fields in the +// candidate structure. +void findOutputBuffer(Value val, AmxDotOpCandidate &candidate) { + if (val.hasOneUse()) { + auto store = dyn_cast(*val.user_begin()); + if (store && !hasMaskOrBoundsCheck(store)) + candidate.outBuf = AmxBuffer{store.getSource(), store.getIndices()}; + candidate.origStore = store; + } +} + +// Check if specified ContractionOp can be lowered to AMX operations. +// If conversion is possible, then true is returned and candidate +// structure is filled with detailed transformation info. +bool isAmxCandidate(vector::ContractionOp op, bool supportInt8, + bool supportFp16, bool supportBf16, + AmxDotOpCandidate &candidate) { + MLIRContext *ctx = op.getContext(); + VectorType lhsTy = cast(op.getLhs().getType()); + VectorType rhsTy = cast(op.getRhs().getType()); + VectorType accTy = cast(op.getAcc().getType()); + VectorType resTy = cast(op.getType()); + + LDBG("Considering candidate op: " << op); + + // Contraction op is very generic. For now, we generate it only as a + // result of DotOp conversion. But still check it's what we expect. + if (!isDotOp(op)) + return false; + + // Check if input and output types match available hardware capabilities. + // If check is successful then tile element types are filled with types + // to use in AMX operations. + if (!checkElemTypes(lhsTy.getElementType(), rhsTy.getElementType(), + accTy.getElementType(), resTy.getElementType(), + supportInt8, supportFp16, supportBf16, candidate)) + return false; + + candidate.op = op; + setupBlockAndTileSizes(lhsTy.getShape(), rhsTy.getShape(), candidate); + candidate.keepAccOnTiles = isLoopCarriedAcc(op.getAcc()); + + // Can't keep acc in a tile the whole loop right now: + // https://github.com/llvm/llvm-project/issues/109481 + if (candidate.keepAccOnTiles) { + // We might not have enough tiles to hold accumulator. In this case + // keep it in a bufffer. + if (candidate.tilesInBlockM * candidate.tilesInBlockN > 1) { + LDBG("Accumulator is too big to keep on tiles. 
Keep it bufferized "
+           "instead.");
+      candidate.keepAccOnTiles = false;
+      candidate.keepAccInBuf = true;
+    } else {
+      findOutputBuffer(getResValueForLoopCarriedAcc(op), candidate);
+    }
+
+    // TODO: fix LLVM bug and remove this code.
+    LDBG("Avoid accumulator on tiles due to LLVM bug: "
+         "https://github.com/llvm/llvm-project/issues/109481.");
+    LDBG("Keep accumulator bufferized instead.");
+    candidate.keepAccOnTiles = false;
+    candidate.keepAccInBuf = true;
+    candidate.outBuf = AmxBuffer{};
+  } else {
+    findOutputBuffer(op.getResult(), candidate);
+  }
+
+  return true;
+}
+
+// Cast vector to a specified element type using ext or trunc
+// operations. Return the original value if it already matches
+// the required element type.
+Value maybeCast(Location loc, Value val, Type dstElemTy,
+                PatternRewriter &rewriter) {
+  VectorType srcTy = cast(val.getType());
+  if (srcTy.getElementType() == dstElemTy)
+    return val;
+
+  VectorType dstTy = srcTy.cloneWith(std::nullopt, dstElemTy);
+  if (srcTy.getElementType().isInteger()) {
+    if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth())
+      return rewriter.create(loc, dstTy, val);
+    return rewriter.create(loc, dstTy, val);
+  }
+
+  if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth())
+    return rewriter.create(loc, dstTy, val);
+  return rewriter.create(loc, dstTy, val);
+}
+
+// Get initial value for a loop-carried accumulator.
+Value getInitAccValue(Value val) {
+  auto blockArg = cast(val);
+  auto forOp = cast(blockArg.getOwner()->getParentOp());
+  int initValIdx = blockArg.getArgNumber() - forOp.getNumInductionVars();
+  return forOp.getInitArgs()[initValIdx];
+}
+
+VectorType getSwizzledRhsTileType(VectorType origTileType) {
+  int64_t rowsPerGroup = 32 / origTileType.getElementTypeBitWidth();
+  SmallVector shape({origTileType.getDimSize(0) / rowsPerGroup,
+                     origTileType.getDimSize(1) * rowsPerGroup});
+  return origTileType.cloneWith(shape, origTileType.getElementType());
+}
+
+AmxBuffer allocateTmpBuffer(Location loc, VectorType vecTy,
+                            Operation *allocaPoint, PatternRewriter &rewriter) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(allocaPoint);
+  auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType());
+  Value memRef = rewriter.create(
+      loc, memRefTy, rewriter.getIntegerAttr(rewriter.getI64Type(), 64));
+  Value zeroIdx = rewriter.create(loc, 0);
+  SmallVector indices(2, zeroIdx);
+  return {memRef, indices};
+}
+
+// In AMX, element values should be packed to 32-bit groups that would be
+// multiplied elementwise with following accumulation. It means that RHS
+// needs to be pre-packed. E.g. for the following input
+// B(0,0) B(0,1) B(0,2) ... B(0,15)
+// B(1,0) B(1,1) B(1,2) ... B(1,15)
+// B(2,0) B(2,1) B(2,2) ... B(2,15)
+// B(3,0) B(3,1) B(3,2) ... B(3,15)
+// and BF16/FP16 type we need to transform it to
+// B(0,0) B(1,0) B(0,1), B(1,1) ... B(0,15) B(1,15)
+// B(2,0) B(3,0) B(2,1), B(3,1) ... B(2,15) B(3,15)
+// so that original columns are 32-bits now. In case of int8 type, the
+// result would be:
+// B(0,0) B(1,0) B(2,0), B(3,0) ... 
B(0,15) B(1,15), B(2,15) B(3,15) +void interleaveAndStore(Location loc, Value val, Value buf, + PatternRewriter &rewriter) { + LDBG("Repacking operand before storing to a buffer."); + VectorType valTy = cast(val.getType()); + int64_t rowsPerGroup = 32 / valTy.getElementTypeBitWidth(); + assert(rowsPerGroup == 2 || rowsPerGroup == 4); + assert(valTy.getDimSize(0) % rowsPerGroup == 0); + Value zeroIdx = rewriter.create(loc, 0); + for (int64_t i = 0; i < valTy.getDimSize(0); i += rowsPerGroup) { + Value row1, row2; + if (rowsPerGroup == 2) { + row1 = rewriter.create(loc, val, i); + row2 = rewriter.create(loc, val, i + 1); + } else { + row1 = rewriter.create( + loc, rewriter.create(loc, val, i), + rewriter.create(loc, val, i + 2)); + row2 = rewriter.create( + loc, rewriter.create(loc, val, i + 1), + rewriter.create(loc, val, i + 3)); + } + Value shuffled = rewriter.create(loc, row1, row2); + Value idx = rewriter.create(loc, i / rowsPerGroup); + rewriter.create(loc, shuffled, buf, + SmallVector({idx, zeroIdx})); + } +} + +// Prepare temporary buffers to be used for tile loads. If the original +// value can be directly loaded to tiles from its original memory, then +// use it instead. Return empty buffer if source value is all zeros and +// skipForZeros is set. +// +// If interleave flag is set, then pre-pack RHS before store. See +// interleaveAndStore for more details. +AmxBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, + bool skipForZeros, bool readOnly, + Operation *allocaPoint, + PatternRewriter &rewriter) { + LDBG("Preparing buffer (interleave=" << interleave + << ") for a vector: " << val); + auto valLoad = val.getDefiningOp(); + if (valLoad && !interleave && readOnly && !hasMaskOrBoundsCheck(valLoad)) { + Value memRef = valLoad.getSource(); + ValueRange indices = valLoad.getIndices(); + LDBG(" Reusing the original memref for a buffer: " << memRef); + return {memRef, indices}; + } + + if (skipForZeros && isZeroConst(val)) { + LDBG("Skip buffer for zero vector."); + return {}; + } + + auto vecTy = cast(val.getType()); + if (interleave) + vecTy = getSwizzledRhsTileType(vecTy); + AmxBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + + if (interleave) { + interleaveAndStore(loc, val, buf.memRef, rewriter); + } else { + rewriter.create(loc, val, buf.memRef, buf.indices); + } + + return buf; +} + +// Return a buffer where the final result should be stored. If result can +// be directly stored to the output memory, then it is used as an output +// buffer. Otherwise, re-use accumulator buffer or create a new one. +AmxBuffer prepareResultBuffer(Location loc, Value val, const AmxBuffer &accBuf, + const AmxBuffer &outBuf, Operation *allocaPoint, + PatternRewriter &rewriter) { + if (!outBuf.empty()) { + LDBG("Output memory will be used for direct tile stores."); + return outBuf; + } + + if (!accBuf.empty()) { + LDBG("Result will be stored to accumulator buffer."); + return accBuf; + } + + LDBG("Allocating buffer for the result."); + return allocateTmpBuffer(loc, cast(val.getType()), allocaPoint, + rewriter); +} + +Value shiftIndex(Location loc, Value index, int64_t offs, + PatternRewriter &rewriter) { + if (!offs) + return index; + + // Do constant folding right away here for better code readability + // after the pass. 
+ auto cstOp = dyn_cast(index.getDefiningOp()); + if (cstOp) { + int64_t oldVal = cast(cstOp.getValue()).getInt(); + return rewriter.create(loc, oldVal + offs); + } + + Value offsVal = rewriter.create(loc, offs); + return rewriter.create(loc, index.getType(), index, offsVal); +} + +SmallVector shiftIndices(Location loc, ArrayRef indices, + VectorType tileTy, int64_t tilesInBlockM, + int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, int64_t tileM, int64_t tileN, + PatternRewriter &rewriter) { + int64_t blockOffsM = blockM * tilesInBlockM * tileTy.getDimSize(0); + int64_t blockOffsN = blockN * tilesInBlockN * tileTy.getDimSize(1); + int64_t tileOffsM = blockOffsM + tileM * tileTy.getDimSize(0); + int64_t tileOffsN = blockOffsN + tileN * tileTy.getDimSize(1); + return {shiftIndex(loc, indices[0], tileOffsM, rewriter), + shiftIndex(loc, indices[1], tileOffsN, rewriter)}; +} + +Value loadTile(Location loc, VectorType tileTy, const AmxBuffer &buf, + int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, int64_t tileM, int64_t tileN, + PatternRewriter &rewriter) { + auto indices = + shiftIndices(loc, buf.indices, tileTy, tilesInBlockM, tilesInBlockN, + blockM, blockN, tileM, tileN, rewriter); + return rewriter.create(loc, tileTy, buf.memRef, indices); +} + +void storeTile(Location loc, VectorType tileTy, Value val, const AmxBuffer &buf, + int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, int64_t tileM, int64_t tileN, + PatternRewriter &rewriter) { + auto indices = + shiftIndices(loc, buf.indices, tileTy, tilesInBlockM, tilesInBlockN, + blockM, blockN, tileM, tileN, rewriter); + rewriter.create(loc, buf.memRef, indices, val); +} + +SmallVector> +loadBlockTiles(Location loc, VectorType tileTy, const AmxBuffer &buf, + int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, PatternRewriter &rewriter) { + SmallVector> res(tilesInBlockM); + for (int64_t m = 0; m < tilesInBlockM; ++m) { + for (int64_t n = 0; n < tilesInBlockN; ++n) { + Value tile = buf.memRef + ? loadTile(loc, tileTy, buf, tilesInBlockM, + tilesInBlockN, blockM, blockN, m, n, rewriter) + : rewriter.create(loc, tileTy); + res[m].push_back(tile); + } + } + return res; +} + +// Move acc to a tile for the whole loop. It might be loads from memory or +// zero tiles. +SmallVector> +moveLoopAccToTiles(Location loc, VectorType tileTy, const AmxBuffer &buf, + int64_t tilesInBlockM, int64_t tilesInBlockN, + PatternRewriter &rewriter) { + LDBG("Loading accumulator to tiles before the loop."); + auto res = loadBlockTiles(loc, tileTy, buf, tilesInBlockM, tilesInBlockN, 0, + 0, rewriter); + + // TODO: add new block args into ForOp and return them instead. + // Yield directly uses them for now and will be patched after mul + // ops generation. + llvm_unreachable("Not yet supported."); + + return res; +} + +// Multiply two blocks. LHS block is preloaded to tiles with the following +// iteration over RHS. Accumulator values are updated in accTiles. +// Optionally, results can also be stored to accBuf. 
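+// This variant is chosen when the LHS block occupies no more tile registers
+// than the RHS block (tilesInBlockM <= tilesInBlockN); otherwise RHS is
+// preloaded instead (see multiplyBlocksPreloadRhs).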
+void multiplyBlocksPreloadLhs(Location loc, VectorType lhsTileTy, + VectorType rhsTileTy, VectorType accTileTy, + const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, + const AmxBuffer &accBuf, int64_t blockM, + int64_t blockN, int64_t blockK, + int64_t tilesInBlockM, int64_t tilesInBlockN, + SmallVector> &accTiles, + bool storeResult, PatternRewriter &rewriter) { + bool isInteger = accTileTy.getElementType().isInteger(); + SmallVector> lhsTiles = loadBlockTiles( + loc, lhsTileTy, lhsBuf, tilesInBlockM, 1, blockM, blockK, rewriter); + + for (int64_t tileN = 0; tileN < tilesInBlockN; ++tileN) { + Value rhsTile = loadTile(loc, rhsTileTy, rhsBuf, 1, tilesInBlockN, blockK, + blockN, 0, tileN, rewriter); + + for (int64_t tileM = 0; tileM < tilesInBlockM; ++tileM) { + if (isInteger) + accTiles[tileM][tileN] = + rewriter.create(loc, accTileTy, lhsTiles[tileM][0], + rhsTile, accTiles[tileM][tileN]); + else + accTiles[tileM][tileN] = + rewriter.create(loc, accTileTy, lhsTiles[tileM][0], + rhsTile, accTiles[tileM][tileN]); + + // Insert store here to better mix stores with multiplications. + if (storeResult) { + storeTile(loc, accTileTy, accTiles[tileM][tileN], accBuf, tilesInBlockM, + tilesInBlockN, blockM, blockN, tileM, tileN, rewriter); + } + } + } +} + +// Similar to multiplyBlocksPreloadLhs but here RHS is preloaded to tiles. +void multiplyBlocksPreloadRhs(Location loc, VectorType lhsTileTy, + VectorType rhsTileTy, VectorType accTileTy, + const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, + const AmxBuffer &accBuf, int64_t blockM, + int64_t blockN, int64_t blockK, + int64_t tilesInBlockM, int64_t tilesInBlockN, + SmallVector> &accTiles, + bool storeResult, PatternRewriter &rewriter) { + bool isInteger = accTileTy.getElementType().isInteger(); + SmallVector> rhsTiles = loadBlockTiles( + loc, rhsTileTy, rhsBuf, 1, tilesInBlockN, blockK, blockN, rewriter); + + for (int64_t tileM = 0; tileM < tilesInBlockM; ++tileM) { + Value lhsTile = loadTile(loc, lhsTileTy, lhsBuf, tilesInBlockM, 1, blockM, + blockK, tileM, 0, rewriter); + + for (int64_t tileN = 0; tileN < tilesInBlockN; ++tileN) { + if (isInteger) + accTiles[tileM][tileN] = rewriter.create( + loc, accTileTy, lhsTile, rhsTiles[0][tileN], + accTiles[tileM][tileN]); + else + accTiles[tileM][tileN] = rewriter.create( + loc, accTileTy, lhsTile, rhsTiles[0][tileN], + accTiles[tileM][tileN]); + + // Insert store here to better mix stores with multiplications. 
+ if (storeResult) { + storeTile(loc, accTileTy, accTiles[tileM][tileN], accBuf, tilesInBlockM, + tilesInBlockN, blockM, blockN, tileM, tileN, rewriter); + } + } + } +} + +LogicalResult convertCandidate(AmxDotOpCandidate &candidate, + PatternRewriter &rewriter) { + vector::ContractionOp op = candidate.op; + Location loc = op.getLoc(); + VectorType lhsTy = cast(op.getLhs().getType()); + VectorType rhsTy = cast(op.getRhs().getType()); + VectorType accTy = cast(op.getAcc().getType()); + VectorType resTy = cast(op.getResultType()); + VectorType lhsTileTy = + lhsTy.cloneWith(SmallVector({candidate.tileM, candidate.tileK}), + candidate.lhsTileElemTy); + VectorType rhsTileTy = getSwizzledRhsTileType( + rhsTy.cloneWith(SmallVector({candidate.tileK, candidate.tileN}), + candidate.rhsTileElemTy)); + VectorType accTileTy = + accTy.cloneWith(SmallVector({candidate.tileM, candidate.tileN}), + candidate.accTileElemTy); + + // If we don't work with a loop and want to directly store tiles into output + // memory, then use the original store as insertion point to have its buffer + // values available for generated code. + if (!candidate.keepAccInBuf && !candidate.keepAccOnTiles && + !candidate.outBuf.empty()) + rewriter.setInsertionPoint(candidate.origStore); + + Operation *allocaPoint = op; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Cast input data if required and prepare input buffer. It might be temporary + // buffers with stored vectors or the original input memory. + Value lhs = maybeCast(loc, op.getLhs(), candidate.lhsTileElemTy, rewriter); + AmxBuffer lhsBuf = + prepareTensorBuffer(loc, lhs, false, false, true, allocaPoint, rewriter); + + Value rhs = maybeCast(loc, op.getRhs(), candidate.rhsTileElemTy, rewriter); + AmxBuffer rhsBuf = + prepareTensorBuffer(loc, rhs, true, false, true, allocaPoint, rewriter); + + Value acc = maybeCast(loc, op.getAcc(), candidate.accTileElemTy, rewriter); + Value accToStore = acc; + scf::ForOp forOp; + if (candidate.keepAccInBuf || candidate.keepAccOnTiles) { + forOp = cast(op->getParentOp()); + accToStore = getInitAccValue(acc); + } + AmxBuffer accBuf; + { + // If accumulator is bufferized then we should move initial values before + // the loop. + OpBuilder::InsertionGuard g(rewriter); + if (candidate.keepAccInBuf) + rewriter.setInsertionPoint(forOp); + accBuf = + prepareTensorBuffer(loc, accToStore, false, !candidate.keepAccInBuf, + false, allocaPoint, rewriter); + } + + AmxBuffer resBuf = prepareResultBuffer( + loc, op.getResult(), accBuf, candidate.outBuf, allocaPoint, rewriter); + + SmallVector> accTiles; + if (candidate.keepAccOnTiles) + accTiles = + moveLoopAccToTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, + candidate.tilesInBlockN, rewriter); + + int64_t blocksInAccM = + accTy.getDimSize(0) / candidate.tileM / candidate.tilesInBlockM; + int64_t blocksInAccN = + accTy.getDimSize(1) / candidate.tileN / candidate.tilesInBlockN; + int64_t tilesInVectorK = lhsTy.getDimSize(1) / candidate.tileK; + for (int64_t blockM = 0; blockM < blocksInAccM; ++blockM) { + for (int64_t blockN = 0; blockN < blocksInAccN; ++blockN) { + if (!candidate.keepAccOnTiles) + accTiles = + loadBlockTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, + candidate.tilesInBlockN, blockM, blockN, rewriter); + + for (int64_t blocK = 0; blocK < tilesInVectorK; ++blocK) { + // We can store accumulator if it is the last block over K dimension. + // TODO: enable forward store for acc kept in tiles. 
+ bool storeAcc = + !candidate.keepAccOnTiles && (blocK == (tilesInVectorK - 1)); + // We need to choose which block (LHS or RHS) to keep on tiles. + // E.g. for ACC block 4x1 tiles, LHS block is also 4 tiles, so + // we would use all tile registers trying to keep both ACC and + // LHS blocks on registers. To decrease register pressure, keep + // the smallest block on tiles. + if (candidate.tilesInBlockM <= candidate.tilesInBlockN) + multiplyBlocksPreloadLhs( + loc, lhsTileTy, rhsTileTy, accTileTy, lhsBuf, rhsBuf, resBuf, + blockM, blockN, blocK, candidate.tilesInBlockM, + candidate.tilesInBlockN, accTiles, storeAcc, rewriter); + else + multiplyBlocksPreloadRhs( + loc, lhsTileTy, rhsTileTy, accTileTy, lhsBuf, rhsBuf, resBuf, + blockM, blockN, blocK, candidate.tilesInBlockM, + candidate.tilesInBlockN, accTiles, storeAcc, rewriter); + } + } + } + + // TODO: For keepAccOnTiles fix YieldOp to use mul results. + // TODO: For keepAccOnTiles move all new forOp results to vector through a + // buffer. + if (candidate.keepAccOnTiles) + llvm_unreachable("Not yet supported."); + + if (candidate.keepAccInBuf) { + int resIdx = op.getResult().getUses().begin()->getOperandNumber(); + Value loopRes = forOp.getResult(resIdx); + LDBG( + "Loading buffererized accumulator to a vector to replace loop result."); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(forOp); + Value newVal = rewriter.create( + loc, cast(acc.getType()), resBuf.memRef, resBuf.indices); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceAllUsesWith(loopRes, newVal); + // For now, just use init value for unused ForOp result instead of + // its removal. + rewriter.replaceOp(op, op.getAcc()); + } else if (candidate.outBuf.empty()) { + LDBG("Loading the result to a vector to replace orig op result."); + Value newVal = rewriter.create( + loc, cast(acc.getType()), resBuf.memRef, resBuf.indices); + // We might need to cast back to the original type. 
+ newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceOp(op, newVal); + } else { + LDBG("Removing original operation and its use."); + rewriter.eraseOp(*op.getResult().user_begin()); + rewriter.eraseOp(op); + } + + return success(); +} + +struct ConvertDotToAMX + : public triton::cpu::impl::ConvertDotToAMXBase { + ConvertDotToAMX() = default; + ConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16) { + this->convertInt8 = convertInt8; + this->convertFp16 = convertFp16; + this->convertBf16 = convertBf16; + } + + void runOnOperation() override { + if (!convertInt8 && !convertFp16 && !convertBf16) + return; + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + SmallVector candidates; + mod->walk([this, &candidates](vector::ContractionOp op) { + AmxDotOpCandidate candidate; + if (isAmxCandidate(op, convertInt8, convertFp16, convertBf16, + candidate)) { + LLVM_DEBUG({ + LDBG("Found AMX candidate"); + LDBG(" Op: " << candidate.op); + LDBG(" LhsTileElemTy: " << candidate.lhsTileElemTy); + LDBG(" RhsTileElemTy: " << candidate.rhsTileElemTy); + LDBG(" AccTileElemTy: " << candidate.accTileElemTy); + LDBG(" TileM: " << candidate.tileM); + LDBG(" TileN: " << candidate.tileN); + LDBG(" TileK: " << candidate.tileK); + LDBG(" TilesInBlockM: " << candidate.tilesInBlockM); + LDBG(" TilesInBlockN: " << candidate.tilesInBlockN); + LDBG(" KeepAccOnTiles: " << candidate.keepAccOnTiles); + LDBG(" KeepAccInBuf: " << candidate.keepAccInBuf); + LDBG(" Has output buffer: " << !candidate.outBuf.empty()); + }); + candidates.push_back(candidate); + } + return WalkResult::advance(); + }); + + for (auto &candidate : candidates) { + LDBG("Starting conversion of candidate: " << candidate.op); + PatternRewriter rewriter(context); + rewriter.setInsertionPoint(candidate.op); + if (succeeded(convertCandidate(candidate, rewriter))) { + LDBG("Conversion succeeded!"); + } else { + LDBG("Conversion failed!"); + } + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotToAMX() { + return std::make_unique(); +} + +std::unique_ptr> +createConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16) { + return std::make_unique(convertInt8, convertFp16, + convertBf16); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 9c4ca64e90b2..a4d86afa5e84 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -12,6 +12,7 @@ #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" @@ -19,6 +20,10 @@ #include #include +#include +#include +#include + namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { @@ -76,6 +81,11 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { bool useHorizontalSum) { pm.addPass(mlir::triton::cpu::createConvertDotProduct(useHorizontalSum)); }); + m.def("add_convert_dot_to_amx", [](mlir::PassManager &pm, bool convertInt8, + bool convertFp16, bool convertBf16) { + pm.addPass(mlir::triton::cpu::createConvertDotToAMX( + convertInt8, convertFp16, convertBf16)); + }); m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm, bool promote_bf16_to_fp32, bool convert_mixed_precision_matmul, bool promote_lib_math_to_fp32) { @@ -121,10 
+131,10 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { mlir::ConvertVectorToLLVMPassOptions opts; opts.reassociateFPReductions = reassoc_fp_reduction; // opts.force32BitVectorIndices = true; - // opts.amx = false; + opts.amx = true; // opts.armNeon = false; // opts.armSVE = false; - // opts.x86Vector = false; + opts.x86Vector = true; pm.addPass(mlir::createConvertVectorToLLVMPass(opts)); }); m.def("add_lower_affine", [](mlir::PassManager &pm) { @@ -148,11 +158,24 @@ void init_triton_cpu(py::module &&m) { auto passes = m.def_submodule("passes"); init_triton_cpu_passes_ttcpuir(passes.def_submodule("ttcpuir")); + m.def("enable_amx", []() -> bool { + // AMX usage requires extended XSTATE which is disabled by default. We + // need to request access to AMX so that XSTATE was dynamically extended + // on the first AMX usage instead of issuing SIGILL. + // See https://www.kernel.org/doc/Documentation/x86/xstate.rst for more + // details. + constexpr int XFEATURE_XTILEDATA = 18; + if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) + return false; + return true; + }); + m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; registry.insert(); mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); + mlir::registerAMXDialectTranslation(registry); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); }); From 9fac57add9c74fbb290715791227046076e4090d Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Mon, 21 Oct 2024 10:18:58 -0400 Subject: [PATCH 124/165] Implement more libdevice functions using extern_elementwise (#161) This PR adds support for libdevice functions that don't map cleanly to a MathOp. We implement them using tt.extern_elementwise instead, indicating which Sleef function to use. While tt.extern_elementwise contains fields for the library path and name, the CUDA backend ignores those fields as it always uses the NVIDIA's libdevice library. We take a similar approach here and assume all extern calls go to the Sleef library. One difference though is that we need to select our Sleef function based on the number of elements of the vector, which is done by interpolating this number into the symbol name. To indicate where this interpolation should occur, I have made `%(numel)` into a special string value. This allows us to reuse tt.extern_elementwise without adding any extra attributes. --- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 21 ++++++++ .../Dialect/TritonCPU/IR/TritonCPUTypes.td | 2 + lib/Dialect/TritonCPU/IR/Ops.cpp | 11 ++++ python/triton/language/extra/cpu/libdevice.py | 27 ++++++++++ .../cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 52 ++++++++++++++++--- .../ConvertElementwiseOps.cpp | 4 ++ 6 files changed, 109 insertions(+), 8 deletions(-) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 23e0fd8fc564..3551984df7c7 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -19,6 +19,27 @@ class TTC_Op traits = []> : !listconcat(traits, [])> { } +// +// External Elementwise op +// +def TTC_ExternElementwiseOp : TTC_Op<"extern_elementwise", [Elementwise, + SameOperandsAndResultEncoding, + SameVariadicOperandSize, + DeclareOpInterfaceMethods]> { + + let description = [{ + Similar to TT_ExternElementwiseOp, but only supports calls to libsleef at the moment. 
The string "%(numel)" in $symbol will be interpolated with the number of elements of
+    the vector argument(s).
+  }];
+
+  let arguments = (ins Variadic:$srcs, StrAttr:$symbol, BoolAttr:$pure);
+
+  let results = (outs TTC_Type:$result);
+
+  let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
+}
+
 def TTC_ExtractMemRefOp : TTC_Op<"extract_memref", [NoMemoryEffect]> {
   let summary = "Extract base memref from a block pointer";
diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td
index 4bd64213db4b..d6ac013804c8 100644
--- a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td
+++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td
@@ -26,4 +26,6 @@ def TTC_TokenType : TTC_TypeDef<"Token", "token"> {
 
 def TTC_Vector : VectorOf<[TT_Float, TT_Int]>;
 
+def TTC_Type : AnyTypeOf<[TT_Float, TT_Int, TTC_Vector]>;
+
 #endif
diff --git a/lib/Dialect/TritonCPU/IR/Ops.cpp b/lib/Dialect/TritonCPU/IR/Ops.cpp
index d626ce3902a9..358ab418ceba 100644
--- a/lib/Dialect/TritonCPU/IR/Ops.cpp
+++ b/lib/Dialect/TritonCPU/IR/Ops.cpp
@@ -15,4 +15,15 @@ LogicalResult PrintOp::verify() {
   return success();
 }
 
+void ExternElementwiseOp::getEffects(
+    SmallVectorImpl>
+        &effects) {
+  if (getPure())
+    return;
+  effects.emplace_back(MemoryEffects::Write::get(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Read::get(),
+                       SideEffects::DefaultResource::get());
+}
+
 } // namespace mlir::triton::cpu
diff --git a/python/triton/language/extra/cpu/libdevice.py b/python/triton/language/extra/cpu/libdevice.py
index d1b410fdd19b..438f49cacf51 100644
--- a/python/triton/language/extra/cpu/libdevice.py
+++ b/python/triton/language/extra/cpu/libdevice.py
@@ -129,6 +129,33 @@ def trunc(arg0, _builder=None):
     return core.tensor(_builder.create_trunc(arg0.handle), arg0.type)
 
 
+@core.extern
+def ceil(arg0, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0], {
+            (core.dtype("fp32"), ): ("Sleef_ceilf%(numel)", core.dtype("fp32")),
+            (core.dtype("fp64"), ): ("Sleef_ceild%(numel)", core.dtype("fp64")),
+        }, is_pure=True, _builder=_builder)
+
+
+@core.extern
+def pow(arg0, arg1, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0, arg1], {
+            (core.dtype("fp32"), core.dtype("fp32")): ("Sleef_powf%(numel)_u10", core.dtype("fp32")),
+            (core.dtype("fp64"), core.dtype("fp64")): ("Sleef_powd%(numel)_u10", core.dtype("fp64")),
+        }, is_pure=True, _builder=_builder)
+
+
+@core.extern
+def fmod(arg0, arg1, _builder=None):
+    return core.extern_elementwise(
+        "", "", [arg0, arg1], {
+            (core.dtype("fp32"), core.dtype("fp32")): ("Sleef_fmodf%(numel)", core.dtype("fp32")),
+            (core.dtype("fp64"), core.dtype("fp64")): ("Sleef_fmodd%(numel)", core.dtype("fp64")),
+        }, is_pure=True, _builder=_builder)
+
+
 @jit
 def _const(v, dtype):
     """
diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
index 68aa6c0bee0c..2d3087b5da2c 100644
--- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
+++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp
@@ -5,6 +5,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -124,7 +125,8 @@ struct DecomposeToNativeVecs : public OpRewritePattern {
       return rewriter.create(loc, val, 
indices); }); - Value subRes = rewriter.create(loc, subResTy, subInputs); + Value subRes = + rewriter.create(loc, subResTy, subInputs, op->getAttrs()); newRes = rewriter.create(loc, subRes, newRes, indices); } @@ -192,20 +194,21 @@ class SleefNameGenerator { std::string ulpSuffix; }; -template struct VecOpToVecLib : public OpRewritePattern { +template +struct OpToVecLibConversion : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; - VecOpToVecLib(MLIRContext *context, GetVecFnNameFn getVecFnName) - : OpRewritePattern(context), getVecFnName(getVecFnName) {} + virtual std::string getVecFnName(OpT op, unsigned bitwidth, + unsigned numel) const = 0; LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { VectorType vecTy = dyn_cast(op.getType()); if (!vecTy || vecTy.getRank() > 1) return failure(); - auto fnName = getVecFnName(vecTy.getElementTypeBitWidth(), - vecTy.getNumElements(), op->getOperands()); + auto fnName = getVecFnName(op, vecTy.getElementTypeBitWidth(), + vecTy.getNumElements()); if (fnName.empty()) return failure(); @@ -229,9 +232,37 @@ template struct VecOpToVecLib : public OpRewritePattern { op->getOperands()); return success(); } +}; + +template +struct VecOpToVecLibConversion : public OpToVecLibConversion { +public: + VecOpToVecLibConversion(MLIRContext *context, GetVecFnNameFn getVecFnName) + : OpToVecLibConversion(context), getVecFnNameImpl(getVecFnName) {} + + std::string getVecFnName(OpT op, unsigned bitwidth, + unsigned numel) const override { + return getVecFnNameImpl(bitwidth, numel, op->getOperands()); + } private: - GetVecFnNameFn getVecFnName; + GetVecFnNameFn getVecFnNameImpl; +}; + +struct ExternElementwiseOpConversion + : public OpToVecLibConversion { + using OpToVecLibConversion::OpToVecLibConversion; + + std::string getVecFnName(triton::cpu::ExternElementwiseOp op, + unsigned bitwidth, unsigned numel) const override { + auto fnName = op.getSymbol(); + auto numelIdx = fnName.find("%(numel)"); + if (numelIdx == StringRef::npos) + return fnName.str(); + return (fnName.take_front(numelIdx) + Twine(numel) + + fnName.drop_front(numelIdx + 8)) + .str(); + } }; template @@ -239,7 +270,8 @@ void populatePatternsForOp(RewritePatternSet &patterns, GetVecFnNameFn getVecFnName) { patterns.add>(patterns.getContext()); patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext(), getVecFnName); + patterns.add>(patterns.getContext(), + getVecFnName); } struct MathToVecLibPass @@ -273,6 +305,10 @@ struct MathToVecLibPass } } + patterns.add>( + patterns.getContext()); + patterns.add(patterns.getContext()); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 4c37524e1b5f..87e0914e1e41 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -59,6 +59,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -249,6 +250,9 @@ struct ConvertElementwiseOps typeConverter, context); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); From 175d629fc1f6fce0568e63648e9573e87918c0aa Mon Sep 17 00:00:00 2001 From: Jez Ng Date: 
Mon, 21 Oct 2024 19:24:01 -0400 Subject: [PATCH 125/165] Fix compilation when ARCH_REQ_XCOMP_PERM isn't defined (#163) Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- third_party/cpu/triton_cpu.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index a4d86afa5e84..0e0ee2757ae4 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -159,6 +159,7 @@ void init_triton_cpu(py::module &&m) { init_triton_cpu_passes_ttcpuir(passes.def_submodule("ttcpuir")); m.def("enable_amx", []() -> bool { +#ifdef ARCH_REQ_XCOMP_PERM // AMX usage requires extended XSTATE which is disabled by default. We // need to request access to AMX so that XSTATE was dynamically extended // on the first AMX usage instead of issuing SIGILL. @@ -168,6 +169,9 @@ void init_triton_cpu(py::module &&m) { if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) return false; return true; +#else + return false; +#endif }); m.def("load_dialects", [](mlir::MLIRContext &context) { From d466759678c8e7018e48747c0f508bcc051d8f24 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 21 Oct 2024 16:44:18 -0700 Subject: [PATCH 126/165] [CPU] Drop MLIR prefix in ScalarizeInterface (#164) --- third_party/cpu/include/ScalarizePass/CMakeLists.txt | 5 ++++- third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/third_party/cpu/include/ScalarizePass/CMakeLists.txt b/third_party/cpu/include/ScalarizePass/CMakeLists.txt index f03fb94e2b1a..4af0f9490fb9 100644 --- a/third_party/cpu/include/ScalarizePass/CMakeLists.txt +++ b/third_party/cpu/include/ScalarizePass/CMakeLists.txt @@ -1 +1,4 @@ -add_mlir_interface(ScalarizeInterface) +set(LLVM_TARGET_DEFINITIONS ScalarizeInterface.td) +mlir_tablegen(ScalarizeInterface.h.inc -gen-op-interface-decls) +mlir_tablegen(ScalarizeInterface.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(ScalarizeInterfaceIncGen) diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index 47023eff75be..0c097fb5923e 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -16,7 +16,7 @@ add_triton_library(TritonToTritonCPU DEPENDS TritonToTritonCPUPassIncGen - MLIRScalarizeInterfaceIncGen + ScalarizeInterfaceIncGen MLIRDialectUtils LINK_LIBS PUBLIC From efa03d9bd16cc62d98c8c58df7c932022bf24db9 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Mon, 21 Oct 2024 22:44:26 -0400 Subject: [PATCH 127/165] Pad size 2 vectors to size 4 when lowering extern_elementwise ops (#162) libsleef does not implement 2-element functions, so for libdevice functions implemented via extern_elementwise ops that rely wholly on libsleef implementations (as opposed to MathOps which can be lowered to native instructions), we need to pad those vectors to size 4. This allows us to enable test_math.py for all the functions introduced in https://github.com/triton-lang/triton-cpu/pull/161. 
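
As a rough sketch (op spellings and attribute printing are illustrative rather
than verbatim compiler output), a 2-element call such as

  %y = triton_cpu.extern_elementwise %x {symbol = "Sleef_ceilf%(numel)", pure = true} : (vector<2xf32>) -> vector<2xf32>

is rewritten to pad the operand with lanes taken from a 1-element undef vector,
call the 4-element Sleef entry point, and then drop the extra lanes:

  %pad  = vector.shuffle %x, %undef [0, 1, 2, 2] : vector<2xf32>, vector<1xf32>
  %call = triton_cpu.extern_elementwise %pad {symbol = "Sleef_ceilf%(numel)", pure = true} : (vector<4xf32>) -> vector<4xf32>
  %y    = vector.shuffle %call, %undef [0, 1] : vector<4xf32>, vector<1xf32>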
--- python/test/unit/cpu/test_math.py | 47 ++++++++++---- .../cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 63 ++++++++++++++++++- 2 files changed, 96 insertions(+), 14 deletions(-) diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py index 793bd0e6f7e7..958913e7f9f1 100644 --- a/python/test/unit/cpu/test_math.py +++ b/python/test/unit/cpu/test_math.py @@ -1,3 +1,4 @@ +import inspect import os import pytest import torch @@ -28,9 +29,9 @@ def is_cpu(): scalar_sizes = [1, 4, 16, 64] -def check_num_vec_calls(meta, vec_lib, dtype_str, size): +def check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=False): # Check generated code calls vector math function - # FP16 and BF16 are casted to FP32 for math ops + # FP16 and BF16 are cast to FP32 for math ops elem_size = 8 if dtype_str == "float64" else 4 data_size = size * elem_size if data_size > 64: @@ -38,7 +39,7 @@ def check_num_vec_calls(meta, vec_lib, dtype_str, size): elif data_size >= 16: num_vec_calls = 1 else: - num_vec_calls = 0 + num_vec_calls = 1 if is_always_extern else 0 assert meta.asm["asm"].count(lib_prefix[vec_lib]) == num_vec_calls @@ -73,28 +74,45 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): chain(product(["libsleef", "libmvec"], vec_sizes), product([None], scalar_sizes))) @pytest.mark.parametrize("dtype_str", float_dtypes) @pytest.mark.parametrize("math_fn", [ - "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "cos", "cosh", "erf", "exp", "exp2", "expm1", "floor", - "isnan", "isinf", "log", "log1p", "log2", "log10", "rsqrt", "signbit", "sin", "sinh", "sqrt", "tan", "tanh", "trunc" + "acos", "acosh", "asin", "asinh", "atan", "atanh", "cbrt", "ceil", "cos", "cosh", "erf", "exp", "exp2", "expm1", + "floor", "fmod", "isnan", "isinf", "log", "log1p", "log2", "log10", "pow", "rsqrt", "signbit", "sin", "sinh", + "sqrt", "tan", "tanh", "trunc" ]) def test_libdevice_math_fn(vec_lib, dtype_str, math_fn, size, device): if not is_cpu(): pytest.skip("This test is CPU-specific") if vec_lib == "libmvec" and arch != "x86_64": pytest.skip("Vectorized libm calls are supported for x86 target only.") + if math_fn in {"ceil", "fmod", "pow"}: + if vec_lib != "libsleef": + pytest.skip("extern_elementwise only supports libsleef") + if dtype_str not in {"float32", "torch.float64"}: + pytest.skip(f"{math_fn} only supports fp32, fp64") @triton.jit - def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + def unary_kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): idxs = tl.arange(0, BLOCK_SIZE) x = tl.load(src + idxs) y = getattr(libdevice, MATH_FN)(x) tl.store(dst + idxs, y) - src = torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) + @triton.jit + def binary_kernel(x_ptr, y_ptr, out_ptr, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): + idxs = tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + idxs) + y = tl.load(y_ptr + idxs) + result = getattr(libdevice, MATH_FN)(x, y) + tl.store(out_ptr + idxs, result) + + signature = inspect.signature(getattr(libdevice, math_fn)) + num_params = len(signature.parameters) + inputs = [torch.rand((size, ), dtype=getattr(torch, dtype_str), device=device) for _ in range(num_params)] # Customize inputs if math_fn == "acosh": - src = src.abs() + 1 + inputs[0] = inputs[0].abs() + 1 if math_fn == "isnan" or math_fn == "isinf": indices = torch.randint(low=0, high=size, size=(size // 2, ), device=device) + src = inputs[0] for i in indices: if math_fn == "isnan": src[i] = float("nan") @@ -103,12 
+121,13 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): # Generate reference output if math_fn == "cbrt": - ref = src.pow(1 / 3) + ref = inputs[0].pow(1 / 3) else: - ref = getattr(src, math_fn)() + ref = getattr(inputs[0], math_fn)(*inputs[1:]) - res = torch.empty(src.shape, dtype=ref.dtype, device=device) - meta = kernel[(1, )](src, res, MATH_FN=math_fn, BLOCK_SIZE=size, vec_lib=vec_lib) + res = torch.empty(inputs[0].shape, dtype=ref.dtype, device=device) + kernel = unary_kernel if num_params == 1 else binary_kernel + meta = kernel[(1, )](*inputs, res, MATH_FN=math_fn, BLOCK_SIZE=size, vec_lib=vec_lib) torch.testing.assert_close(ref, res) if vec_lib is None: @@ -119,7 +138,9 @@ def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr): "libmvec": {"expm1", "floor", "isnan", "isinf", "rsqrt", "signbit", "sqrt", "trunc"}, "libsleef": {"isnan", "isinf", "rsqrt", "signbit"}, } + # These are always implemented with extern library calls + always_extern = {"ceil", "fmod", "pow"} if math_fn not in native_impls[vec_lib]: - check_num_vec_calls(meta, vec_lib, dtype_str, size) + check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=math_fn in always_extern) else: assert meta.asm["asm"].count(lib_prefix[vec_lib]) == 0 diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index 2d3087b5da2c..0d3b2a368e44 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -2,12 +2,14 @@ #include "cpu/include/TritonCPUToLLVM/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" @@ -136,6 +138,64 @@ struct DecomposeToNativeVecs : public OpRewritePattern { } }; +using ExternElementwiseOp = triton::cpu::ExternElementwiseOp; + +/* + * libsleef does not contain implementations for 2-element vectors, so we pad + * any such vectors to size 4 instead. 
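+ * The extra lanes are filled from a one-element vector built from an undef
+ * scalar via vector.shuffle, and a second shuffle truncates the 4-element
+ * result back to the original size.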
+ */ +struct PadSmallVecsForSleef : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + PadSmallVecsForSleef(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(ExternElementwiseOp op, + PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + VectorType vecTy = dyn_cast(op.getType()); + if (!vecTy) + return failure(); + + Type elemTy = vecTy.getElementType(); + if (!elemTy.isF32() && !elemTy.isF64()) + return failure(); + + int64_t numElems = vecTy.getNumElements(); + if (numElems >= 4) + return failure(); + + // Create a single-element vector for shuffle to use + auto paddingVec = rewriter.create( + loc, undef(elemTy), VectorType::get({1}, elemTy)); + // Assign indices such that shuffle will pad the original vector with + // elements from the paddingVec + SmallVector indices(4); + for (int i = 0; i < 4; ++i) { + if (i < numElems) + indices[i] = i; + else + indices[i] = numElems; + } + SmallVector newOperands; + for (auto argVal : op.getOperands()) { + auto shuf = + rewriter.create(loc, argVal, paddingVec, indices); + newOperands.push_back(shuf.getResult()); + } + // Update return type of extern call + auto newVecTy = VectorType::get({4}, elemTy); + auto extern_elem = rewriter.create( + loc, newVecTy, newOperands, op.getSymbol(), op.getPure()); + indices.resize(numElems); + // Truncate result to original size + rewriter.replaceOpWithNewOp(op, extern_elem.getResult(), + paddingVec, indices); + return success(); + } +}; + using GetVecFnNameFn = std::function; @@ -305,8 +365,9 @@ struct MathToVecLibPass } } - patterns.add>( + patterns.add>( patterns.getContext()); + patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) From 4b83250b1f1163978f26d6c4959c40a56232c74b Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 21 Oct 2024 21:24:00 -0700 Subject: [PATCH 128/165] Rebase onto upstream triton ff306da26b and fix regressions --- .github/workflows/integration-tests.yml | 2 + .../Dialect/TritonCPU/IR/TritonCPUOps.td | 4 +- third_party/cpu/backend/compiler.py | 2 + third_party/cpu/backend/driver.py | 4 ++ .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 37 ++++++++++++------- .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 4 ++ .../TritonCPUTransforms/ConvertDotProduct.cpp | 16 ++++---- .../lib/TritonToTritonCPU/ConvertDebugOps.cpp | 5 +-- 8 files changed, 48 insertions(+), 26 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index c86266885a98..1ea6e68c6c82 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -11,12 +11,14 @@ on: workflow_dispatch: # Disabled automatic triggers because tests in this workflow fail to run. # pull_request: +# # You can name your branch dev-foo to get CI runs. 
# branches-ignore: ['llvm-**'] # merge_group: # branches: [main, 'dev-**'] # types: [checks_requested] # push: # branches: [main] + concurrency: group: ${{ github.ref }} cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 3551984df7c7..e987866d35cf 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -162,8 +162,8 @@ def TTC_AssertOp : TTC_Op<"assert", [MemoryEffects<[MemWrite]>]> { Takes a condition tensor, a message string, a file string, a function string, and a line number. If the condition is false, the message is printed, and the program is aborted. }]; - let arguments = (ins I1:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line); - let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)"; + let arguments = (ins I1:$condition, StrAttr:$message); + let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; } diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 2e346a259977..6150defb6128 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -42,6 +42,8 @@ class CPUOptions: max_num_imprecise_acc_default: int = 0 enable_fast_math: bool = True vec_lib: Optional[str] = 'libsleef' + # TODO: Try to enable it. + sanitize_overflow: bool = False # TODO: We may introduce CPU-specific options like # of cores. diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 9bc9db4379f8..c23d353d2735 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -375,3 +375,7 @@ def get_current_target(self): @staticmethod def is_active(): return True + + def get_benchmarker(self): + from triton.testing import do_bench + return do_bench diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index ed26c2196a44..a921bfeaecf0 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -256,20 +256,29 @@ struct AssertOpConversion Value message = LLVM::addStringToModule(loc, rewriter, "assertMessage_", makeNullTerminatedString(adaptor.getMessage())); - Value file = - LLVM::addStringToModule(loc, rewriter, "assertFile_", - makeNullTerminatedString(adaptor.getFile())); - Value func = - LLVM::addStringToModule(loc, rewriter, "assertFunc_", - makeNullTerminatedString(adaptor.getFunc())); - SmallVector args{getPid(op, 0), - getPid(op, 1), - getPid(op, 2), - op.getCondition(), - message, - file, - i32_val(adaptor.getLine()), - func}; + + // Based on lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp. 
+ StringRef fileStr = "unknown"; + StringRef funcStr = "unknown"; + int line = 0; + int col = 0; + + while (auto callLoc = dyn_cast(loc)) + loc = callLoc.getCallee(); + + if (auto fileLineColLoc = dyn_cast(loc)) { + fileStr = fileLineColLoc.getFilename(); + line = fileLineColLoc.getLine(); + col = fileLineColLoc.getColumn(); + } + + Value file = LLVM::addStringToModule(loc, rewriter, "assertFile_", + makeNullTerminatedString(fileStr)); + Value func = LLVM::addStringToModule(loc, rewriter, "assertFunc_", + makeNullTerminatedString(funcStr)); + SmallVector args{getPid(op, 0), getPid(op, 1), getPid(op, 2), + op.getCondition(), message, file, + i32_val(line), func}; call(getAssertFuncDecl(rewriter), args); rewriter.eraseOp(op); return success(); diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp index 0f0193da57cf..99962da6546a 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -217,6 +217,10 @@ struct CallOpConversion : public ConvertOpToLLVMPattern { auto newCallOp = rewriter.create( callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), promotedOperands, callOp->getAttrs()); + newCallOp.getProperties().setOpBundleSizes( + rewriter.getDenseI32ArrayAttr({})); + newCallOp.getProperties().setOperandSegmentSizes( + {static_cast(promotedOperands.size()), 0}); return newCallOp; } diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp index 93cccb60fb1b..da96eea967cd 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp @@ -204,7 +204,7 @@ struct ConvertMulSumToDotHorizontalSum SmallVector resultTypes = {outResTy}; // TODO: this intrinsic is hard-coded for Arm Neon - llvm::StringRef bfdotIntrinsic("llvm.aarch64.neon.bfdot.v4f32.v8bf16"); + auto bfdot = StringAttr::get(ctx, "llvm.aarch64.neon.bfdot.v4f32.v8bf16"); SmallVector args; for (int64_t idx = 0; idx < numOfBfdotOps; idx += 1) { @@ -221,7 +221,8 @@ struct ConvertMulSumToDotHorizontalSum // pair of adjacent bf16 elements in the source vectors (8 bf16), and // output 4 fp32 elements. 
auto callIntrOp = rewriter.create( - loc, resultTypes, bfdotIntrinsic, args, LLVM::FastmathFlags::fast); + loc, resultTypes, bfdot, args, + LLVM::FastmathFlagsAttr::get(ctx, LLVM::FastmathFlags::fast)); outRes[outIdx] = callIntrOp.getResult(0); } } @@ -231,13 +232,14 @@ struct ConvertMulSumToDotHorizontalSum resultTypes = {resTy.getElementType()}; // TODO: this intrinsic is hard-coded for Arm Neon - llvm::StringRef horizSumIntrinsic("llvm.aarch64.neon.faddv.f32.v4f32"); + auto horzSum = StringAttr::get(ctx, "llvm.aarch64.neon.faddv.f32.v4f32"); for (int64_t outIdx = 0; outIdx < numOfOutputChannels; outIdx += 1) { args = {outRes[outIdx]}; // This horizontal sum intrinsic will sum all fp32 elements in the source // vector into a single fp32 element auto callIntrOp = rewriter.create( - loc, resultTypes, horizSumIntrinsic, args, LLVM::FastmathFlags::fast); + loc, resultTypes, horzSum, args, + LLVM::FastmathFlagsAttr::get(ctx, LLVM::FastmathFlags::fast)); res = rewriter.create(loc, callIntrOp.getResult(0), res, outIdx); } @@ -398,7 +400,7 @@ struct ConvertMulSumToDotPack loc, fullResTy, rewriter.getZeroAttr(fullResTy)); SmallVector resultTypes = {subResTy}; // TODO: this intrinsic is hard-coded for Arm Neon - llvm::StringRef bfdotIntrinsic("llvm.aarch64.neon.bfdot.v4f32.v8bf16"); + auto bfdot = StringAttr::get(ctx, "llvm.aarch64.neon.bfdot.v4f32.v8bf16"); SmallVector args; SmallVector subRes(numOfOutputRegs); @@ -429,8 +431,8 @@ struct ConvertMulSumToDotPack // each pair of adjacent bf16 elements in the source vectors // (8 bf16), and output 4 fp32 elements. auto callIntrOp = rewriter.create( - loc, resultTypes, bfdotIntrinsic, args, - LLVM::FastmathFlags::fast); + loc, resultTypes, bfdot, args, + LLVM::FastmathFlagsAttr::get(ctx, LLVM::FastmathFlags::fast)); subRes[outIdx] = callIntrOp.getResult(0); } } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp index 8a83156e4c52..72a11c510526 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -86,9 +86,8 @@ struct AssertOpConversion : public OpConversionPattern { cast(condition.getType()).getRank(), true); condition = rewriter.create( loc, condition, acc, dimsToReduce, vector::CombiningKind::AND); - rewriter.replaceOpWithNewOp( - op, condition, op.getMessage(), op.getFile(), op.getFunc(), - op.getLine()); + rewriter.replaceOpWithNewOp(op, condition, + op.getMessage()); return success(); } }; From 6cf9ff03d851b77a1666e76358aedf4b18226770 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 23 Oct 2024 11:34:07 -0500 Subject: [PATCH 129/165] Simple fixes to build on MacOSx (#165) --- third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 5 +++-- third_party/cpu/triton_cpu.cc | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index 0d3b2a368e44..65ff0bd41e72 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -231,11 +231,12 @@ class MvecNameGenerator { class SleefNameGenerator { public: SleefNameGenerator(StringRef baseName, unsigned ulp = 10) - : baseName(baseName), ulpSuffix(4, '\0') { + : baseName(baseName), ulpSuffix(5, '\0') { if (ulp == 0) ulpSuffix = ""; else - sprintf(ulpSuffix.data(), "_u%02u", ulp); + // snprintf inserts '\0' at the end + snprintf(ulpSuffix.data(), 
ulpSuffix.size(), "_u%02u", ulp); } std::string operator()(unsigned bitwidth, unsigned numel, diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 0e0ee2757ae4..9db53e1d0c71 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -20,7 +20,9 @@ #include #include +#ifdef __linux__ #include +#endif // __linux__ #include #include @@ -159,7 +161,7 @@ void init_triton_cpu(py::module &&m) { init_triton_cpu_passes_ttcpuir(passes.def_submodule("ttcpuir")); m.def("enable_amx", []() -> bool { -#ifdef ARCH_REQ_XCOMP_PERM +#if defined(__linux__) && defined(ARCH_REQ_XCOMP_PERM) // AMX usage requires extended XSTATE which is disabled by default. We // need to request access to AMX so that XSTATE was dynamically extended // on the first AMX usage instead of issuing SIGILL. @@ -171,7 +173,7 @@ void init_triton_cpu(py::module &&m) { return true; #else return false; -#endif +#endif // __linux__ && ARCH_REQ_XCOMP_PERM }); m.def("load_dialects", [](mlir::MLIRContext &context) { From e24a63ace07bf0e4b81dde92035fa360b03b343f Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 23 Oct 2024 18:29:53 -0500 Subject: [PATCH 130/165] Fix trailing null char in ulpSuffix (#166) --- third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index 65ff0bd41e72..b68d5a7473d0 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -231,12 +231,14 @@ class MvecNameGenerator { class SleefNameGenerator { public: SleefNameGenerator(StringRef baseName, unsigned ulp = 10) - : baseName(baseName), ulpSuffix(5, '\0') { - if (ulp == 0) + : baseName(baseName), ulpSuffix(4, '\0') { + if (ulp == 0) { ulpSuffix = ""; - else - // snprintf inserts '\0' at the end - snprintf(ulpSuffix.data(), ulpSuffix.size(), "_u%02u", ulp); + } else { + char buf[5]; // 4 char suffix + '\0' added by snprintf + snprintf(buf, 5, "_u%02u", ulp); + ulpSuffix = buf; + } } std::string operator()(unsigned bitwidth, unsigned numel, From b2f8c992dfcfbd3a5986963478f8a6c23e2cfa17 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Wed, 23 Oct 2024 19:36:49 -0700 Subject: [PATCH 131/165] Rebase onto upstream triton 4a5431159a and fix regressions --- python/triton/runtime/autotuner.py | 3 +- python/triton/testing.py | 60 +---------------- python/tutorials/01-vector-add.py | 21 +++--- python/tutorials/02-fused-softmax-cpu.py | 18 +++-- .../tutorials/03-matrix-multiplication-cpu.py | 16 ++--- python/tutorials/05-layer-norm.py | 4 +- .../matrix-vector-multiplication-bf16.py | 14 ++-- .../tutorials/matrix-vector-multiplication.py | 29 +++----- third_party/cpu/backend/driver.py | 67 +++++++++++++++++++ 9 files changed, 109 insertions(+), 123 deletions(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 23ff224998cd..339b79529537 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -164,8 +164,7 @@ def kernel_call(): self.post_hook(full_nargs, exception=None) try: - device = driver.active.get_current_target().backend - return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), device_type=device) + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) except (OutOfResources, CompileTimeAssertionFailure, PTXASError) as e: if verbose: print(f"Autotuning failed with {e}") diff --git 
a/python/triton/testing.py b/python/triton/testing.py index ed47eca834b5..5dadabaa1001 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -4,67 +4,10 @@ import statistics import subprocess import sys -import time from contextlib import contextmanager from typing import Any, Dict, List from . import language as tl from . import runtime -import triton - - -class CPUDeviceInterface: - - class HooksTimeAccessor: - - def __init__(self, di): - self.di = di - self.record_idx = 0 - - def elapsed_time(self, end_event) -> float: - total_time = 0 - for i in range(self.record_idx, end_event.record_idx): - total_time += self.di.kernel_times[i] - return total_time * 1000 - - def record(self): - self.record_idx = len(self.di.kernel_times) - - class TimerEvent: - - def __init__(self): - self.timer = 0 - - def elapsed_time(self, end_event) -> float: - return (end_event.timer - self.timer) * 1000 - - def record(self): - self.timer = time.perf_counter() - - def __init__(self): - self.kernel_times = [] - self.last_start = 0 - self.use_hooks = False - triton.compiler.CompiledKernel.launch_enter_hook = None - triton.compiler.CompiledKernel.launch_exit_hook = None - - def enable_hook_timing(self): - self.use_hooks = True - triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook() - triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook() - - def synchronize(self): - pass - - def _enter_hook(self): - self.last_start = time.perf_counter() - - def _exit_hook(self): - self.kernel_times.append(time.perf_counter() - self.last_start) - - def Event(self, enable_timing=True): - if self.use_hooks: - return CPUDeviceInterface.HooksTimeAccessor(self) - return CPUDeviceInterface.TimerEvent() def nvsmi(attrs): @@ -177,7 +120,8 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None, quantiles=None, return_mod return _summarize_statistics(ret, quantiles, return_mode) -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", measure_time_with_hooks=False): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean", + measure_time_with_hooks=False): """ Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with the 20-th and 80-th performance percentile. diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 90f65b12e351..5c9cf2aa75e8 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -206,31 +206,26 @@ def benchmark(size, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles, device_type=DEVICE) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles, - device_type=DEVICE) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, None, False), quantiles=quantiles) elif provider == 'torch-cpu': # Note that we preallocate the output buffer here to only measure the kernel performance # without a large chunk of memory allocation. 
- ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles, - device_type=DEVICE) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.add(x, y, out=output), quantiles=quantiles) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, DEVICE), quantiles=quantiles, - device_type=DEVICE) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, DEVICE), quantiles=quantiles) elif provider == 'triton-cpu-hooks': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y, output, DEVICE), quantiles=quantiles, - device_type=DEVICE, measure_time_with_hooks=True) + measure_time_with_hooks=True) elif provider == 'triton-cpu-tiled': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles, - device_type=DEVICE) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles) elif provider == 'triton-cpu-tiled-hooks': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled(x, y, output), quantiles=quantiles, - device_type=DEVICE, measure_time_with_hooks=True) + measure_time_with_hooks=True) elif provider == 'triton-cpu-tiled-tuned-hooks': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_with_st_threshold(x, y, output), - quantiles=quantiles, device_type=DEVICE, - measure_time_with_hooks=True) + quantiles=quantiles, measure_time_with_hooks=True) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/02-fused-softmax-cpu.py b/python/tutorials/02-fused-softmax-cpu.py index 355277d4bb9c..ee298102c2c7 100644 --- a/python/tutorials/02-fused-softmax-cpu.py +++ b/python/tutorials/02-fused-softmax-cpu.py @@ -211,24 +211,22 @@ def benchmark(M, N, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-cpu-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles) if provider == 'torch-cpu-jit': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles) if provider == 'torch-cpu-compile': compiled = torch.compile(naive_softmax) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles) if provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles) if provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles) if provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles) if provider == 'torch-gpu-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = 
triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles) if provider == 'torch-gpu-jit': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles) gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index 91378c5db836..2ace29240b9b 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -378,22 +378,18 @@ def benchmark(M, N, K, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles) elif provider == 'torch-cpu-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles) elif provider == 'torch-cpu-compile': compiled = torch.compile(torch.matmul) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index bc5716b8792c..2ebad39be370 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -359,13 +359,13 @@ def y_fwd(): # forward pass if mode == 'forward': gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) - ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) # backward pass if mode == 'backward': y = y_fwd() gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) # noqa: F811, E704 ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, - grad_to_none=[x], rep=500, device_type=device) + grad_to_none=[x], rep=500) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/matrix-vector-multiplication-bf16.py b/python/tutorials/matrix-vector-multiplication-bf16.py index 7993e4090b20..76162ef68c4f 100644 --- a/python/tutorials/matrix-vector-multiplication-bf16.py +++ 
b/python/tutorials/matrix-vector-multiplication-bf16.py @@ -169,20 +169,16 @@ def benchmark(M, N, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) elif 'torch-cpu-native' in provider: - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles) elif 'torch-cpu-compile' in provider: ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_matmul(weight, x, out=output), - quantiles=quantiles, device_type=device) + quantiles=quantiles) elif 'triton-cpu' in provider: - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/matrix-vector-multiplication.py b/python/tutorials/matrix-vector-multiplication.py index 5d44ddf9c2c2..c9bb0b0f525e 100644 --- a/python/tutorials/matrix-vector-multiplication.py +++ b/python/tutorials/matrix-vector-multiplication.py @@ -173,38 +173,29 @@ def benchmark(M, N, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x), quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) elif provider == 'torch-cpu-native' or provider == 'torch-cpu-2d-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(weight, x, out=output), quantiles=quantiles) elif provider == 'torch-cpu-compile' or provider == 'torch-cpu-2d-compile': compiled = torch.compile(torch.matmul) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(weight, x, out=output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(weight, x, out=output), quantiles=quantiles) elif provider == 'torch-cpu-transpose-native': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, weight, out=output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(x, weight, out=output), quantiles=quantiles) elif provider == 'torch-cpu-transpose-compile': compiled = torch.compile(torch.matmul) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x, weight, out=output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = 
triton.testing.do_bench(lambda: compiled(x, weight, out=output), quantiles=quantiles) elif provider == 'torch-cpu-linear': weight = torch.nn.Linear(N, M, bias=False, device=weight.device, dtype=weight.dtype) - ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles, device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) elif provider == 'triton-cpu-linear': # torch.nn.Linear.forward does not take preallocated output buffer, so we also do no provide output buffer for fair comparison - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles, - device_type=device) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, None), quantiles=quantiles) perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index c23d353d2735..403838ba0ab1 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -3,7 +3,9 @@ import importlib import importlib.resources import tempfile +import time +import triton import triton._C from triton.runtime.build import _build from triton.runtime.cache import get_cache_manager @@ -353,6 +355,61 @@ def __call__(self, *args, **kwargs): self.launch(*args, **kwargs) +class CPUDeviceInterface: + + class HooksTimeAccessor: + + def __init__(self, di): + self.di = di + self.record_idx = 0 + + def elapsed_time(self, end_event) -> float: + total_time = 0 + for i in range(self.record_idx, end_event.record_idx): + total_time += self.di.kernel_times[i] + return total_time * 1000 + + def record(self): + self.record_idx = len(self.di.kernel_times) + + class TimerEvent: + + def __init__(self): + self.timer = 0 + + def elapsed_time(self, end_event) -> float: + return (end_event.timer - self.timer) * 1000 + + def record(self): + self.timer = time.perf_counter() + + def __init__(self): + self.kernel_times = [] + self.last_start = 0 + self.use_hooks = False + triton.compiler.CompiledKernel.launch_enter_hook = None + triton.compiler.CompiledKernel.launch_exit_hook = None + + def enable_hook_timing(self): + self.use_hooks = True + triton.compiler.CompiledKernel.launch_enter_hook = lambda arg: self._enter_hook() + triton.compiler.CompiledKernel.launch_exit_hook = lambda arg: self._exit_hook() + + def synchronize(self): + pass + + def _enter_hook(self): + self.last_start = time.perf_counter() + + def _exit_hook(self): + self.kernel_times.append(time.perf_counter() - self.last_start) + + def Event(self, enable_timing=True): + if self.use_hooks: + return CPUDeviceInterface.HooksTimeAccessor(self) + return CPUDeviceInterface.TimerEvent() + + class CPUDriver(DriverBase): def __init__(self): @@ -372,6 +429,9 @@ def get_current_target(self): cpu_arch = llvm.get_cpu_tripple().split("-")[0] return GPUTarget("cpu", cpu_arch, 0) + def get_device_interface(self): + return CPUDeviceInterface() + 
@staticmethod def is_active(): return True @@ -379,3 +439,10 @@ def is_active(): def get_benchmarker(self): from triton.testing import do_bench return do_bench + + def get_empty_cache_for_benchmark(self): + import torch + + # A typical LLC size for high-end server CPUs are ~400MB. + cache_size = 512 * 1024 * 1024 + return torch.empty(int(cache_size // 4), dtype=torch.int, device='cpu') From 76b3225ddc2b10f11a62b1e3582c7ade98629468 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Fri, 25 Oct 2024 18:30:01 +0200 Subject: [PATCH 132/165] [Test][Autotuner] Skip use_cuda_graph for non cuda devices (#169) This commit skips test for non-cuda devices that exects to use_cuda_graph. Signed-off-by: Dmitrii Makarenko --- python/test/unit/runtime/test_autotuner.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py index fa835eeeb8e7..8379f69e3804 100644 --- a/python/test/unit/runtime/test_autotuner.py +++ b/python/test/unit/runtime/test_autotuner.py @@ -4,6 +4,8 @@ import triton.language as tl import pytest +from triton._internal_testing import is_cuda + def do_bench(kernel_call, quantiles): return triton.testing.do_bench(kernel_call, quantiles=quantiles, warmup=1, rep=1) @@ -11,8 +13,10 @@ def do_bench(kernel_call, quantiles): @pytest.mark.parametrize('use_cuda_graph', [False, True]) def test_kwargs(use_cuda_graph: bool, device: str): - if use_cuda_graph and not torch.cuda.is_available(): + + if not is_cuda() and use_cuda_graph: pytest.xfail("CUDA is not available") + pytest.skip("Use cuda graph without cuda looks strange") M, N = 1024, 16 src = torch.randn(M * N, device=device) From c8c4bcea9f4f7207e4ab37d645269d626ff75007 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 28 Oct 2024 09:28:53 -0500 Subject: [PATCH 133/165] Add num_threads option to control threading per kernel invocation. (#170) Signed-off-by: Ilya Enkovich --- python/tutorials/02-fused-softmax-cpu.py | 10 +++---- .../tutorials/03-matrix-multiplication-cpu.py | 10 +++---- .../matrix-vector-multiplication-bf16.py | 12 +++++---- .../tutorials/matrix-vector-multiplication.py | 12 ++++----- third_party/cpu/backend/compiler.py | 5 ++++ third_party/cpu/backend/driver.py | 27 +++++++++---------- 6 files changed, 35 insertions(+), 41 deletions(-) diff --git a/python/tutorials/02-fused-softmax-cpu.py b/python/tutorials/02-fused-softmax-cpu.py index ee298102c2c7..e93ed9d37b50 100644 --- a/python/tutorials/02-fused-softmax-cpu.py +++ b/python/tutorials/02-fused-softmax-cpu.py @@ -100,7 +100,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor. -def softmax(x, y=None): +def softmax(x, y=None, num_threads=0): n_rows, n_cols = x.shape # The block size is the smallest power of two greater than the number of columns in `x` BLOCK_SIZE = triton.next_power_of_2(n_cols) @@ -126,6 +126,7 @@ def softmax(x, y=None): n_cols, num_warps=num_warps, BLOCK_SIZE=BLOCK_SIZE, + num_threads=num_threads, ) return y @@ -190,7 +191,6 @@ def softmax(x, y=None): args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` )) def benchmark(M, N, provider): - import os # Currently compilation time is very long. Let's show the progress. 
print(f"Running {provider} with {M} x {N}...") @@ -201,10 +201,6 @@ def benchmark(M, N, provider): if device == 'cpu': y = torch.empty_like(x) triton.runtime.driver.set_active_to_cpu() - if 'single' in provider: - os.environ['TRITON_CPU_SINGLE_CORE'] = '1' - else: - os.unsetenv('TRITON_CPU_SINGLE_CORE') else: y = None triton.runtime.driver.set_active_to_gpu() @@ -218,7 +214,7 @@ def benchmark(M, N, provider): compiled = torch.compile(naive_softmax) ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(x), quantiles=quantiles) if provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y, num_threads=1), quantiles=quantiles) if provider == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x, y), quantiles=quantiles) if provider == 'triton-gpu': diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index 2ace29240b9b..c61a8098eac4 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -248,7 +248,7 @@ def matmul_kernel( # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. -def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.is_contiguous(), "Matrix A must be contiguous" @@ -272,6 +272,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): c.stride(0), c.stride(1), # BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # + num_threads=num_threads, # ) return c @@ -359,7 +360,6 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor): args={}, # Values for function arguments not in `x_names` and `y_name`. 
)) def benchmark(M, N, K, provider): - import os device = 'cpu' if 'cpu' in provider else 'cuda' a = torch.randn((M, K), device=device, dtype=torch.float32) @@ -368,10 +368,6 @@ def benchmark(M, N, K, provider): if device == 'cpu': c = torch.empty((M, N), device=a.device, dtype=a.dtype) triton.runtime.driver.set_active_to_cpu() - if 'single' in provider: - os.environ['TRITON_CPU_SINGLE_CORE'] = '1' - else: - os.unsetenv('TRITON_CPU_SINGLE_CORE') else: c = None triton.runtime.driver.set_active_to_gpu() @@ -387,7 +383,7 @@ def benchmark(M, N, K, provider): compiled = torch.compile(torch.matmul) ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, num_threads=1), quantiles=quantiles) elif provider == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) diff --git a/python/tutorials/matrix-vector-multiplication-bf16.py b/python/tutorials/matrix-vector-multiplication-bf16.py index 76162ef68c4f..9927d2be956a 100644 --- a/python/tutorials/matrix-vector-multiplication-bf16.py +++ b/python/tutorials/matrix-vector-multiplication-bf16.py @@ -50,6 +50,7 @@ def gemv( weight: torch.Tensor, x: torch.Tensor, output: torch.Tensor, + num_threads=0, ): assert weight.shape[1] == x.shape[0], "Incompatible dimensions" assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous" @@ -69,7 +70,8 @@ def gemv( # 1D launch kernel where each block gets its own program. grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), ) - gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N) + gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, + num_threads=num_threads) return output @@ -148,7 +150,6 @@ def gemv( args={}, # Values for function arguments not in `x_names` and `y_name`. 
)) def benchmark(M, N, provider): - import os device = 'cpu' if 'cpu' in provider else 'cuda' weight = torch.randn((M, N), device=device, dtype=torch.bfloat16) @@ -157,11 +158,11 @@ def benchmark(M, N, provider): if device == 'cpu': output = torch.empty((M), device=x.device, dtype=x.dtype) triton.runtime.driver.set_active_to_cpu() + num_threads = 0 if 'single' in provider: - os.environ['TRITON_CPU_SINGLE_CORE'] = '1' + num_threads = 1 torch.set_num_threads(1) else: - os.unsetenv('TRITON_CPU_SINGLE_CORE') torch.set_num_threads(default_num_threads) else: output = None @@ -178,7 +179,8 @@ def benchmark(M, N, provider): ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled_matmul(weight, x, out=output), quantiles=quantiles) elif 'triton-cpu' in provider: - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output, num_threads=num_threads), + quantiles=quantiles) perf = lambda ms: 2 * M * N * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/matrix-vector-multiplication.py b/python/tutorials/matrix-vector-multiplication.py index c9bb0b0f525e..06feca82893f 100644 --- a/python/tutorials/matrix-vector-multiplication.py +++ b/python/tutorials/matrix-vector-multiplication.py @@ -49,6 +49,7 @@ def gemv( weight: torch.Tensor, x: torch.Tensor, output: torch.Tensor, + num_threads=0, ): assert weight.shape[1] == x.shape[0], "Incompatible dimensions" assert weight.is_contiguous() and x.is_contiguous(), "Input and weight must be contiguous" @@ -68,7 +69,8 @@ def gemv( # 1D launch kernel where each block gets its own program. grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]), ) - gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N) + gemv_kernel[grid](output, weight, x, M, N, weight.stride(0), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, + num_threads=num_threads) return output @@ -146,7 +148,6 @@ def gemv( args={'M': 4096}, # Values for function arguments not in `x_names` and `y_name`. 
)) def benchmark(M, N, provider): - import os device = 'cpu' if 'cpu' in provider else 'cuda' weight = torch.randn((M, N), device=device, dtype=torch.float32) @@ -155,10 +156,6 @@ def benchmark(M, N, provider): if device == 'cpu': output = torch.empty((M), device=x.device, dtype=x.dtype) triton.runtime.driver.set_active_to_cpu() - if 'single' in provider: - os.environ['TRITON_CPU_SINGLE_CORE'] = '1' - else: - os.unsetenv('TRITON_CPU_SINGLE_CORE') if 'transpose' in provider: weight = torch.transpose(weight, 0, 1) @@ -190,7 +187,8 @@ def benchmark(M, N, provider): weight = torch.nn.Linear(N, M, bias=False, device=weight.device, dtype=weight.dtype) ms, min_ms, max_ms = triton.testing.do_bench(lambda: weight.forward(x), quantiles=quantiles) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output, num_threads=1), + quantiles=quantiles) elif provider == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: gemv(weight, x, output), quantiles=quantiles) elif provider == 'triton-cpu-linear': diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6150defb6128..a3e44e97120b 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -27,9 +27,14 @@ class CPUOptions: # GPU-specific options are used in several places. # For now, we just provide dummy values. backend_name: str = "cpu" + # These options provide compatibility with GPU kernel calls. + # All of them are ignored. num_warps: int = 0 num_stages: int = 0 num_ctas: int = 0 + # Max number of threads to be used for a kernel call. + # Zero value is used to utilize all available CPU cores. + num_threads: int = 0 cluster_dims: tuple = (1, 1, 1) extern_libs: dict = None debug: bool = False diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 403838ba0ab1..f3d3cace794e 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -229,7 +229,7 @@ def format_of(ty): return grids; }} -static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ +static void run_omp_kernels(uint32_t gridX, uint32_t gridY, uint32_t gridZ, int num_threads, kernel_ptr_t kernel_ptr {', ' + arg_decls if len(arg_decls) > 0 else ''}) {{ // TODO: Consider using omp collapse(3) clause for simplicity? size_t N = gridX * gridY * gridZ; if (N == 1) {{ @@ -238,10 +238,10 @@ def format_of(ty): }} auto all_grids = get_all_grids(gridX, gridY, gridZ); - if (getBoolEnv("TRITON_CPU_SINGLE_CORE")) {{ - if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) - printf("Single core launcher\\n"); + int max_threads = (num_threads > 0) ? num_threads : omp_get_max_threads(); + // Don't pay OMP overhead price when a single thread is used. 
+ if (max_threads == 1) {{ for (size_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ); @@ -249,17 +249,8 @@ def format_of(ty): return; }} - std::optional max_threads = getIntEnv("TRITON_CPU_MAX_THREADS"); - if (max_threads.has_value()) - max_threads = std::max(1, std::min(max_threads.value(), omp_get_max_threads())); - else - max_threads = omp_get_max_threads(); - - if (getBoolEnv("TRITON_CPU_OMP_DEBUG")) - printf("N: %zu, max_threads: %d\\n", N, max_threads.value()); - // For now, use the default chunk size, total iterations / max_threads. -#pragma omp parallel for schedule(static) num_threads(max_threads.value()) +#pragma omp parallel for schedule(static) num_threads(max_threads) for (size_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ); @@ -285,6 +276,12 @@ def format_of(ty): void *pStream = PyLong_AsVoidPtr(py_obj_stream); kernel_ptr_t kernel_ptr = reinterpret_cast(pKrnl); + // Extract num_threads metadata. + int num_threads = 0; + PyObject *num_threads_attr = PyObject_GetAttrString(kernel_metadata, "num_threads"); + if (num_threads_attr && PyLong_Check(num_threads_attr)) + num_threads = PyLong_AsLong(num_threads_attr); + // extract launch metadata if (launch_enter_hook != Py_None){{ PyObject* args = Py_BuildValue("(O)", launch_metadata); @@ -295,7 +292,7 @@ def format_of(ty): }} {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; - run_omp_kernels(gridX, gridY, gridZ, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); + run_omp_kernels(gridX, gridY, gridZ, num_threads, kernel_ptr {',' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''}); if(launch_exit_hook != Py_None){{ PyObject* args = Py_BuildValue("(O)", launch_metadata); From 857388630d2e77d96d3ed403eb712dabb4fe9f1a Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Mon, 28 Oct 2024 15:44:30 +0100 Subject: [PATCH 134/165] [TTC Print Memref] Simplify further multidimensional tensor printing (#160) This commit adds Memref type to possible inputs of print. Memref have strides and other supporting information to allow print multidimensional tensors. (2d, 3d etc) Such print will be added in the next pr. 
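
As a usage sketch (illustrative only, not added by this patch; the kernel,
names, and shapes are made up), a tensor print that is routed through the new
unranked-memref runtime call could look like:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def print_block(x_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        x = tl.load(x_ptr + offs)
        # Tensor operands are stored into an allocated memref, cast to an
        # unranked memref, and handed to triton_print_unranked_memref;
        # scalar operands still go through the printf-based path.
        tl.device_print("x: ", x)

    x = torch.arange(16, dtype=torch.float32)
    print_block[(1, )](x, BLOCK=16)

This assumes the CPU backend is the active driver (e.g. via
triton.runtime.driver.set_active_to_cpu(), as done in the tutorials).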
--- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 2 +- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 112 ++++---- .../lib/TritonToTritonCPU/ConvertDebugOps.cpp | 52 +++- third_party/cpu/runtime/cpu_runtime.cpp | 255 ++++++++++++++---- 4 files changed, 317 insertions(+), 104 deletions(-) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index e987866d35cf..28eefd383f9c 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -145,7 +145,7 @@ def TTC_PrintOp : TTC_Op<"print", [MemoryEffects<[MemWrite]>]> { let arguments = (ins StrAttr:$prefix, BoolAttr:$hex, - Variadic>:$val, + Variadic>:$val, DenseI32ArrayAttr:$isSigned ); diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index a921bfeaecf0..a6d1487ebcc7 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -2,8 +2,6 @@ #include "Utility.h" #include "cpu/include/TritonCPUToLLVM/Passes.h" - -#include "mlir/Dialect/GPU/IR/GPUOps.h.inc" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -96,8 +94,8 @@ Value printfPromoteValue(RewriterBase &rewriter, Value value) { return value; } -LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter, - bool printf) { +LLVM::LLVMFuncOp getOrAddPrintFuncDecl(ConversionPatternRewriter &rewriter, + bool printf) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); StringRef funcName = printf ? "printf" : "triton_vector_print"; Operation *funcOp = moduleOp.lookupSymbol(funcName); @@ -122,15 +120,50 @@ LLVM::LLVMFuncOp getPrintFuncDecl(ConversionPatternRewriter &rewriter, funcType); } +LLVM::LLVMFuncOp +getOrAddPrintMemrefFuncDecl(ConversionPatternRewriter &rewriter) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + StringRef funcName = "triton_print_unranked_memref"; + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + SmallVector argsType; + + SmallVector elemTypes; + elemTypes.push_back(i64_ty); + elemTypes.push_back(ptr_ty(ctx)); + Type structTy = struct_ty(elemTypes); + + argsType = {/*pid serialization*/ i32_ty, + i32_ty, + i32_ty, /*end pids*/ + ptr_ty(ctx), + structTy, + /*type sreialization*/ i32_ty, + i32_ty, + i32_ty, /*end type*/ + i32_ty}; + auto funcType = + LLVM::LLVMFunctionType::get(i32_ty, argsType, /*isVarArg*/ false); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); +} + static StringRef makeNullTerminatedString(StringRef s) { llvm::SmallString<64> ss(s); ss.push_back(0); return ss; } -void llPrintf(StringRef prefix, std::array pid, - std::optional arg, ConversionPatternRewriter &rewriter, - bool hex = false) { +void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, + std::array pid, StringRef prefix, + std::optional arg, bool hex = false) { assert(!prefix.empty() && "printf with empty string not supported"); auto loc = UnknownLoc::get(rewriter.getContext()); @@ -152,30 +185,30 @@ void llPrintf(StringRef prefix, std::array pid, allArgs.push_back(elem); if (arg.has_value()) allArgs.push_back(printfPromoteValue(rewriter, arg.value())); - call(getPrintFuncDecl(rewriter, 
true), allArgs); + call(getOrAddPrintFuncDecl(rewriter, true), allArgs); } -void llVectorPrint(std::array pid, StringRef prefix, Value ptr, - bool isInteger, bool isSigned, uint32_t bitWidth, - int64_t numElem, bool hex, - ConversionPatternRewriter &rewriter) { +void createRuntimePrintCall(ConversionPatternRewriter &rewriter, + std::array pid, StringRef prefix, + Value ptr, Type dtype, bool hex) { assert(!prefix.empty()); auto loc = UnknownLoc::get(rewriter.getContext()); - Value prefixValue = LLVM::addStringToModule( loc, rewriter, "vectorPrintPrefix_", makeNullTerminatedString(prefix)); SmallVector allArgs; for (auto elem : pid) allArgs.push_back(elem); + allArgs.push_back(prefixValue); allArgs.push_back(ptr); - allArgs.push_back(i32_val(isInteger)); - allArgs.push_back(i32_val(isSigned)); - allArgs.push_back(i32_val(bitWidth)); - allArgs.push_back(i64_val(numElem)); + + allArgs.push_back(i32_val(dtype.getIntOrFloatBitWidth())); + allArgs.push_back(i32_val(dtype.isInteger())); + allArgs.push_back(i32_val(dtype.isSignedInteger())); allArgs.push_back(i32_val(hex)); - call(getPrintFuncDecl(rewriter, false), allArgs); + + call(getOrAddPrintMemrefFuncDecl(rewriter), allArgs); } bool usePrintf(triton::cpu::PrintOp op) { @@ -205,39 +238,24 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { if (usePrintf(op)) { if (op.getNumOperands() == 0) { - llPrintf(op.getPrefix(), pid, std::nullopt, rewriter); - } else { - Value llOpr = adaptor.getOperands()[0]; - llPrintf(op.getPrefix(), pid, llOpr, rewriter, op.getHex()); - } - } else { - Value llOpr = adaptor.getOperands()[0]; - auto vecShapedType = cast(op.getOperands()[0].getType()); - // Currently, we only support 1D vector printing. - if (vecShapedType.getRank() == 1) { - - // To get the pointer of the vector, create an alloca and store it. - auto ptrType = ptr_ty(rewriter.getContext()); - auto ptr = rewriter.create(loc, ptrType, - llOpr.getType(), i32_val(1)); - rewriter.create(loc, llOpr, ptr); - - // TODO: Consider passing an encoded element type information instead of - // booleans and separate bit width. - llVectorPrint(pid, op.getPrefix(), ptr, - vecShapedType.getElementType().isInteger(), - op.getIsSigned()[0], - vecShapedType.getElementTypeBitWidth(), - vecShapedType.getNumElements(), op.getHex(), rewriter); + createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(), + std::nullopt); } else { - // TODO: support 2D+ vector printing. - std::string msg{op.getPrefix()}; - llvm::raw_string_ostream os(msg); - os << "<>"; - llPrintf(msg, pid, std::nullopt, rewriter); + createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(), + adaptor.getOperands()[0], op.getHex()); } + rewriter.eraseOp(op); + return success(); } + // TODO: support 2D+ vector printing. 
+ std::string msg{op.getPrefix()}; + + createRuntimePrintCall( + rewriter, pid, op.getPrefix(), adaptor.getOperands()[0], + cast(op.getVal()[0].getType()).getElementType(), + op.getHex()); + rewriter.eraseOp(op); return success(); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp index 72a11c510526..83fe858fb139 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -2,6 +2,7 @@ #include "cpu/include/TritonToTritonCPU/Passes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Pass/Pass.h" @@ -33,6 +34,9 @@ class DebugOpsConversionTarget : public ConversionTarget { addLegalDialect(); addLegalOp(); + addLegalOp(); + addLegalOp(); + addLegalOp(); addIllegalOp(); addIllegalOp(); @@ -53,18 +57,44 @@ struct PrintOpConversion : public OpConversionPattern { rewriter.create(loc, op.getPrefix(), op.getHex(), ValueRange{}, llvm::SmallVector{}); - } else { - // triton_cpu.print takes up to one vector or scalar operand. It prints - // each value as a separate print call like the GPU and interpreter. - assert(op.getNumOperands() == op.getIsSigned().size()); - for (size_t i = 0; i < op.getNumOperands(); i++) { - Value opr = op.getOperands()[i]; - llvm::SmallVector isSigned = {op.getIsSigned()[i]}; - // TODO: Consider using memrefs for general N-dimensional vectors. - rewriter.create(loc, op.getPrefix(), op.getHex(), - rewriter.getRemappedValue(opr), - isSigned); + rewriter.eraseOp(op); + return success(); + } + + for (auto operand : op.getOperands()) { + if (!isa(operand.getType())) { + rewriter.create( + loc, op.getPrefix(), op.getHex(), + rewriter.getRemappedValue(operand), false); + continue; + } + + auto tensorTy = cast(operand.getType()); + auto elemTy = tensorTy.getElementType(); + if (isa(elemTy)) { + elemTy = rewriter.getI64Type(); } + MemRefType memRefTy = MemRefType::get(tensorTy.getShape(), elemTy); + + Value allocVal = rewriter.create( + loc, memRefTy, rewriter.getI64IntegerAttr(64)); + + Value vec = rewriter.getRemappedValue(operand); + VectorType vecTy = cast(vec.getType()); + + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(vecTy.getRank(), zeroIdx); + + rewriter.create(loc, vec, allocVal, indices); + + Value allocUnrankedVal = rewriter.create( + loc, UnrankedMemRefType::get(elemTy, memRefTy.getMemorySpace()), + allocVal); + + rewriter.create(loc, op.getPrefix(), op.getHex(), + allocUnrankedVal, false); + + rewriter.create(loc, allocVal); } rewriter.eraseOp(op); diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index dba79828dc5b..06889b732364 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #if defined(_MSC_VER) #define EXPORT __declspec(dllexport) @@ -121,57 +122,57 @@ void printElement(std::stringstream &ss, const void *vec, size_t index, switch (formatInfo.bitWidth) { case 32: printElementHelper(ss, vec, index); - break; + return; case 64: printElementHelper(ss, vec, index); - break; + return; default: llvm_unreachable("Unsupported bitWidth"); } - } else { - if (formatInfo.isSigned) { - switch (formatInfo.bitWidth) { - case 64: - printElementHelper(ss, vec, index); - break; - case 32: - printElementHelper(ss, vec, index); - break; - case 16: - printElementHelper(ss, vec, 
index); - break; - case 8: - // int8_t is printed as char. - ss << static_cast(static_cast(vec)[index]); - break; - case 1: - printElementHelper(ss, vec, index); - break; - default: - llvm_unreachable("Unsupported bitWidth"); - } - } else { - switch (formatInfo.bitWidth) { - case 64: - printElementHelper(ss, vec, index); - break; - case 32: - printElementHelper(ss, vec, index); - break; - case 16: - printElementHelper(ss, vec, index); - break; - case 8: - ss << static_cast(static_cast(vec)[index]); - break; - case 1: - printElementHelper(ss, vec, index); - break; - default: - llvm_unreachable("Unsupported bitWidth"); - } + } + + if (formatInfo.isSigned) { + switch (formatInfo.bitWidth) { + case 64: + printElementHelper(ss, vec, index); + return; + case 32: + printElementHelper(ss, vec, index); + return; + case 16: + printElementHelper(ss, vec, index); + return; + case 8: + // int8_t is printed as char. + ss << static_cast(static_cast(vec)[index]); + return; + case 1: + printElementHelper(ss, vec, index); + return; + default: + llvm_unreachable("Unsupported bitWidth"); } } + + switch (formatInfo.bitWidth) { + case 64: + printElementHelper(ss, vec, index); + return; + case 32: + printElementHelper(ss, vec, index); + return; + case 16: + printElementHelper(ss, vec, index); + return; + case 8: + ss << static_cast(static_cast(vec)[index]); + return; + case 1: + printElementHelper(ss, vec, index); + return; + default: + llvm_unreachable("Unsupported bitWidth"); + } } void printFormattedElement(std::stringstream &ss, void *vec, size_t index, @@ -201,6 +202,157 @@ void printFormattedElement(std::stringstream &ss, void *vec, size_t index, printElement(ss, vec, index, formatInfo); } } + +template struct RawMemRefDescriptor { + T *allocated; + T *aligned; + intptr_t offset; + intptr_t sizesAndStrides[]; +}; + +template struct MemRefDescriptor { + T *allocated; + T *aligned; + intptr_t offset; + std::vector sizes; + std::vector strides; + int32_t rank; + + MemRefDescriptor(int32_t rank, void *rawDescriptor) : rank(rank) { + auto *rawDesc = static_cast *>(rawDescriptor); + allocated = rawDesc->allocated; + aligned = rawDesc->aligned; + offset = rawDesc->offset; + sizes.resize(rank); + strides.resize(rank); + for (int32_t i = 0; i < rank; i++) { + sizes[i] = rawDesc->sizesAndStrides[i]; + strides[i] = rawDesc->sizesAndStrides[i + rank]; + } + } +}; + +struct UnrankedMemRefType { + int64_t rank; + void *descriptor; +}; + +template +void printToStream(MemRefDescriptor &&desc, std::stringstream &ss, + FormatInfo &partialFormatInfo) { + + if (desc.rank > 1) { + ss << "<>\n"; + return; + } + if (desc.sizes.size() == 0) { + ss << "<>\n"; + } + + T *vec = desc.aligned; + int32_t numElems = desc.sizes[0]; + + FormatInfo formatInfo = getFormatInfo( + vec, partialFormatInfo.isInt, partialFormatInfo.isSigned, + partialFormatInfo.bitWidth, numElems, partialFormatInfo.isHex); + + const size_t header = ss.str().size(); + + if (numElems <= ELEMS_PER_LINE) { + for (int i = 0; i < numElems; i++) { + printFormattedElement(ss, vec, i, formatInfo); + if (i != numElems - 1) + ss << ", "; + } + } else { + // TODO: Too many lines? Omit the middle lines. 
+ for (int i = 0; i < numElems; i++) { + printFormattedElement(ss, vec, i, formatInfo); + if (i == numElems - 1) + break; + if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { + ss << ",\n" << std::string(header, ' '); + } else { + ss << ", "; + } + } + } + ss << "]\n"; +} + +void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, + int32_t btw, bool isInteger, bool isSignedInteger, + bool asHex) { + + FormatInfo partialFormat{.isInt = isInteger, + .isSigned = isSignedInteger, + .bitWidth = btw, + .isHex = asHex}; + if (!isInteger) { + switch (btw) { + case 64: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 32: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + default: + llvm_unreachable("Unsupported bitWidth"); + } + } + if (isSignedInteger) { + switch (btw) { + case 64: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 32: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 16: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 8: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 1: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + default: + llvm_unreachable("Unsupported bitWidth"); + } + } + switch (btw) { + case 64: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 32: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 16: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 8: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat); + return; + case 1: + printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat); + return; + default: + llvm_unreachable("Unsupported bitWidth"); + } +} + } // namespace extern "C" { @@ -223,7 +375,7 @@ EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond, // // TODO: Implement for higher dimension vectors. EXPORT void triton_vector_print(int32_t pid0, int32_t pid1, int32_t pid2, - const char *prefix, void *vec, int32_t isInt, + const char *prefix, void *vec, bool isInt, bool isSigned, int32_t bitWidth, int64_t numElem, bool isHex) { @@ -257,4 +409,17 @@ EXPORT void triton_vector_print(int32_t pid0, int32_t pid1, int32_t pid2, std::cout << ss.str() << std::flush; } +EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1, + int32_t pid2, const char *prefix, + UnrankedMemRefType memref, int32_t btw, + bool isInteger, bool isSignedInteger, + bool asHex) { + std::stringstream ss; + ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix; + + printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger, + isSignedInteger, asHex); + std::cout << ss.str() << std::flush; +} + } // extern "C" From 3d528f7dd0fd771ba6c63baf0ac0384149f90481 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 30 Oct 2024 10:06:48 -0500 Subject: [PATCH 135/165] Small fixes for autotuner on CPU (#172) * Enable num_threads in autotuner and use hooks for tuning on CPU. Signed-off-by: Ilya Enkovich * Add vector-add example for CPU with autotuner. 
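
For illustration, a minimal autotuned CPU kernel using the new knob might look like the sketch below; the block sizes, tensor size, and kernel body are placeholders only, and the complete, authoritative example is the `add_kernel_tiled_autotuned` tutorial code added further down in this patch. Run it with `TRITON_CPU_BACKEND=1` so the CPU driver (which now defaults benchmarking to hook-based timing) is active.

import torch
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        # num_threads=1 sidesteps OMP threading overhead for small inputs;
        # num_threads=0 means "use all available threads".
        triton.Config({'BLOCK_SIZE': 4096}, num_threads=1),
        triton.Config({'BLOCK_SIZE': 16384}, num_threads=0),
    ],
    key=['n_elements'],
)
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def bench(n=1 << 20):
    x, y = torch.rand(n), torch.rand(n)
    out = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    # Hook-based timing measures the kernel body without launch overhead.
    return triton.testing.do_bench(lambda: add_kernel[grid](x, y, out, n),
                                   measure_time_with_hooks=True)
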
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- python/triton/runtime/autotuner.py | 5 ++- python/tutorials/01-vector-add.py | 51 ++++++++++++++++++++++++++++-- third_party/cpu/backend/driver.py | 8 ++++- 3 files changed, 59 insertions(+), 5 deletions(-) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 339b79529537..c738a1c4aaef 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -269,11 +269,12 @@ class Config: function are args. """ - def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, maxnreg=None, pre_hook=None): + def __init__(self, kwargs, num_warps=4, num_stages=3, num_ctas=1, num_threads=0, maxnreg=None, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps self.num_ctas = num_ctas self.num_stages = num_stages + self.num_threads = num_threads self.maxnreg = maxnreg self.pre_hook = pre_hook @@ -285,6 +286,7 @@ def all_kwargs(self): ("num_warps", self.num_warps), ("num_ctas", self.num_ctas), ("num_stages", self.num_stages), + ("num_threads", self.num_threads), ("maxnreg", self.maxnreg), ) if v is not None } @@ -297,6 +299,7 @@ def __str__(self): res.append(f"num_warps: {self.num_warps}") res.append(f"num_ctas: {self.num_ctas}") res.append(f"num_stages: {self.num_stages}") + res.append(f"num_threads: {self.num_threads}") res.append(f"maxnreg: {self.maxnreg}") return ", ".join(res) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 5c9cf2aa75e8..f6eb176af82a 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -78,6 +78,39 @@ def add_kernel_tiled(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) +@triton.autotune( + configs=[ + # For small vectors it might be faster to use a single thread instead + # of paying OMP threading overhead, so add a single-threaded option. + # Other options use all available threads. + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 4096}, num_threads=1), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 4096}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 8192}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 16384}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 32768}, num_threads=0), + triton.Config({'TILE_SIZE': 16, 'BLOCK_SIZE': 65536}, num_threads=0), + ], + key=['n_elements'], +) +@triton.jit +def add_kernel_tiled_autotuned(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + TILE_SIZE: tl.constexpr, # Number of elements each iteration should process. 
+ ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + for i in range(0, tl.cdiv(BLOCK_SIZE, TILE_SIZE)): + offsets = block_start + i * TILE_SIZE + tl.arange(0, TILE_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + # %% # Let's also declare a helper function to (1) allocate the `z` tensor # and (2) enqueue the above kernel with appropriate grid/block sizes: @@ -125,6 +158,15 @@ def add_tiled_with_st_threshold(x: torch.Tensor, y: torch.Tensor, output): return output +def add_tiled_autotuned(x: torch.Tensor, y: torch.Tensor, output): + if output is None: + output = torch.empty_like(x) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + add_kernel_tiled_autotuned[grid](x, y, output, n_elements) + return output + + # %% # We can now use the above function to compute the element-wise sum of two `torch.tensor` objects and test its correctness: torch.manual_seed(0) @@ -144,13 +186,13 @@ def add_tiled_with_st_threshold(x: torch.Tensor, y: torch.Tensor, output): LINE_VALS = [ 'triton-cpu', 'triton-cpu-hooks', 'triton-cpu-tiled', 'triton-cpu-tiled-hooks', 'triton-cpu-tiled-tuned-hooks', - 'torch-cpu' + 'triton-cpu-tiled-autotuned-hooks', 'torch-cpu' ] LINE_NAMES = [ 'TritonCPU', 'TritonCPU (hooks)', 'TritonCPUTiled', 'TritonCPUTiled (hooks)', 'TritonCPUTiled (tuned, hooks)', - 'TorchCPU' + 'TritonCPUTiled (autotuned, hooks)', 'TorchCPU' ] -LINE_STYLES = [('blue', '--'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('blue', '-'), ('green', '-')] +LINE_STYLES = [('blue', '--'), ('blue', '-.'), ('red', '-'), ('red', '--'), ('red', '-.'), ('red', ':'), ('green', '-')] if USE_GPU and triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() @@ -226,6 +268,9 @@ def benchmark(size, provider): elif provider == 'triton-cpu-tiled-tuned-hooks': ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_with_st_threshold(x, y, output), quantiles=quantiles, measure_time_with_hooks=True) + elif provider == 'triton-cpu-tiled-autotuned-hooks': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: add_tiled_autotuned(x, y, output), quantiles=quantiles, + measure_time_with_hooks=True) gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index f3d3cace794e..3e2e10614fd2 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -435,7 +435,13 @@ def is_active(): def get_benchmarker(self): from triton.testing import do_bench - return do_bench + + def do_bench_cpu(*args, **kwargs): + if not 'measure_time_with_hooks' in kwargs: + kwargs['measure_time_with_hooks'] = True + return do_bench(*args, **kwargs) + + return do_bench_cpu def get_empty_cache_for_benchmark(self): import torch From 4a778e60fb2322218512d76fbc81eb13a264985f Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 30 Oct 2024 12:58:17 -0500 Subject: [PATCH 136/165] Small fixes for clang + macosx (#173) --- python/src/llvm.cc | 29 ++++++++++++++++++++++++----- python/triton/runtime/build.py | 9 ++++++++- third_party/cpu/backend/compiler.py | 2 +- third_party/cpu/backend/driver.py | 10 +++++++++- 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index b0c67a5e6ff0..95dfc74d82e5 100644 --- 
a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -45,13 +45,28 @@ struct BreakStructPhiNodesPass : PassInfoMixin { using namespace llvm; +std::string getDefaultTargerOrProcessTriple() { + // Return process triple iff the default target triple is empty. + std::string triple = llvm::sys::getDefaultTargetTriple(); + if (triple.empty()) { + // host + triple = llvm::sys::getProcessTriple(); + } + return triple; +} + std::unique_ptr createTargetMachine(llvm::Module *module, std::string proc, bool enable_fp_fusion, const std::string &features, bool enable_fast_math = false) { + auto triple = getDefaultTargerOrProcessTriple(); + module->setTargetTriple(triple); std::string error; auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } llvm::TargetOptions opt; bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); if (enable_fp_fusion) @@ -418,10 +433,14 @@ void init_triton_llvm(py::module &&m) { py::arg("enable_fp_fusion") = false); m.def("set_host_target", [](llvm::Module *mod) { - mod->setTargetTriple(llvm::sys::getDefaultTargetTriple()); + auto triple = getDefaultTargerOrProcessTriple(); + mod->setTargetTriple(triple); std::string error; auto target = llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } std::unique_ptr machine{target->createTargetMachine( mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {}, llvm::Reloc::PIC_)}; @@ -448,10 +467,10 @@ void init_triton_llvm(py::module &&m) { "failed to parse IR: " + error.getMessage() + "lineno: " + std::to_string(error.getLineNo())); } - res = - translateLLVMIRToASM(*module, llvm::sys::getDefaultTargetTriple(), - llvm::sys::getHostCPUName().str(), "", {}, - enable_fp_fusion, false, enable_fast_math); + auto triple = getDefaultTargerOrProcessTriple(); + res = translateLLVMIRToASM(*module, triple, + llvm::sys::getHostCPUName().str(), "", {}, + enable_fp_fusion, false, enable_fast_math); } return py::str(res); }, diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index c44659be31bc..67f3a82a21fd 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -46,9 +46,14 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs] # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047 cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so] + + libraries += ["gcc"] # Use dynamic lookup to load Python library on Mac if system == "Darwin": cc_cmd += ["-undefined", "dynamic_lookup"] + # Don't use libgcc on clang + macos + if "clang" in cc: + libraries.remove("gcc") cc_cmd += [f'-l{lib}' for lib in libraries] cc_cmd += [f"-L{dir}" for dir in library_dirs] cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None] @@ -56,7 +61,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): cc_cmd.extend(["-Wl,-rpath", dir]) # CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag. 
if src.endswith(".cpp") or src.endswith(".cc"): - cc_cmd += ["-std=c++17", "-fopenmp"] + cc_cmd += ["-std=c++17"] + if not os.environ.get("TRITON_DISABLE_OPENMP", None): + cc_cmd += ["-fopenmp"] if src.endswith(".s"): # This is required to properly parse .file directives cc_cmd += ["-g"] diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index a3e44e97120b..5aabcc051b91 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -263,7 +263,7 @@ def make_so(src, metadata, options): asm_path = os.path.join(tmpdir, "kernel.s") Path(asm_path).write_text(src) lib_dirs = cpu_driver.library_dirs - libs = ["gcc", "m", "TritonCPURuntime", "sleef"] + libs = ["m", "TritonCPURuntime", "sleef"] so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs) with open(so, "rb") as f: return f.read() diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 3e2e10614fd2..cadb76e1229a 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -143,7 +143,9 @@ def format_of(ty): #include #include #include +#ifdef _OPENMP #include +#endif // _OPENMP #include #include #include @@ -238,7 +240,11 @@ def format_of(ty): }} auto all_grids = get_all_grids(gridX, gridY, gridZ); - int max_threads = (num_threads > 0) ? num_threads : omp_get_max_threads(); + int omp_max_threads = 1; + #ifdef _OPENMP + omp_max_threads = omp_get_max_threads(); + #endif // _OPENMP + int max_threads = (num_threads > 0) ? num_threads : omp_max_threads; // Don't pay OMP overhead price when a single thread is used. if (max_threads == 1) {{ @@ -250,7 +256,9 @@ def format_of(ty): }} // For now, use the default chunk size, total iterations / max_threads. +#ifdef _OPENMP #pragma omp parallel for schedule(static) num_threads(max_threads) +#endif // _OPENMP for (size_t i = 0; i < N; ++i) {{ const auto [x, y, z] = all_grids[i]; (*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ); From 30734666c83a8b15f175198feb75c1fbc2187f07 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 30 Oct 2024 16:02:41 -0500 Subject: [PATCH 137/165] Support multi-dimensional tensor prints in CPU runtime. 
(#174) Signed-off-by: Ilya Enkovich --- third_party/cpu/runtime/cpu_runtime.cpp | 375 +++++++++--------------- 1 file changed, 145 insertions(+), 230 deletions(-) diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 06889b732364..537441903212 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -34,70 +34,99 @@ struct FormatInfo { bool isHex; }; +template struct RawMemRefDescriptor { + const T *allocated; + const T *aligned; + intptr_t offset; + intptr_t sizesAndStrides[]; +}; + +template class MemRefDescriptor { +private: + const T *data_; + std::vector sizes_; + std::vector strides_; + + MemRefDescriptor(const T *data, std::vector sizes, + std::vector strides) + : data_(data), sizes_(std::move(sizes)), strides_(std::move(strides)) {} + +public: + MemRefDescriptor(int32_t rank, void *rawDescriptor) { + auto *rawDesc = static_cast *>(rawDescriptor); + data_ = rawDesc->aligned + rawDesc->offset; + sizes_.insert(sizes_.begin(), rawDesc->sizesAndStrides, + rawDesc->sizesAndStrides + rank); + strides_.insert(strides_.begin(), rawDesc->sizesAndStrides + rank, + rawDesc->sizesAndStrides + rank * 2); + } + + const T *data() const { return data_; } + + int64_t rank() const { return static_cast(sizes_.size()); } + + int64_t size(int64_t dim) const { return sizes_[dim]; } + + int64_t stride(int64_t dim) const { return strides_[dim]; } + + MemRefDescriptor subView(int64_t idx) const { + assert(rank() > 1); + return {data_ + idx * stride(0), + {sizes_.begin() + 1, sizes_.end()}, + {strides_.begin() + 1, strides_.end()}}; + } +}; + +struct UnrankedMemRefType { + int64_t rank; + void *descriptor; +}; + template -std::pair -computeDigitInfoHelper(const void *array, size_t index) { - T elem = static_cast(array)[index]; - if (elem == 0) +std::pair computeDigitInfo(T val) { + if (val == 0) return {1, false}; - return {static_cast(std::log10(elem >= 0 ? elem : -elem)) + 1, elem < 0}; + int digits = + std::max(static_cast(std::log10(val >= 0 ? 
val : -val)), 0) + 1; + return {digits, val < 0}; } -std::pair computeDigitInfo(void *vec, bool isInt, bool isSigned, - int32_t bitWidth, size_t index) { - if (isInt == 0) { - if (bitWidth == 32) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 64) - return computeDigitInfoHelper(vec, index); - else - llvm_unreachable("Unsupported bitWidth"); +template +std::tuple computeDigitStats(const MemRefDescriptor &desc) { + int maxIntDigits = 0; + int minIntDigits = std::numeric_limits::max(); + bool hasNegative = false; + + if (desc.rank() == 1) { + const T *data = desc.data(); + int64_t stride = desc.stride(0); + for (int64_t i = 0; i < desc.size(0); ++i) { + auto [digits, negative] = computeDigitInfo(data[i * stride]); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, digits); + minIntDigits = std::min(minIntDigits, digits); + } } else { - if (isSigned) { - if (bitWidth == 64) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 32) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 16) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 8) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 1) - return computeDigitInfoHelper(vec, index); - } else { - if (bitWidth == 64) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 32) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 16) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 8) - return computeDigitInfoHelper(vec, index); - else if (bitWidth == 1) - return computeDigitInfoHelper(vec, index); + for (int64_t i = 0; i < desc.size(0); ++i) { + auto [maxDigits, minDigits, negative] = + computeDigitStats(desc.subView(i)); + hasNegative |= negative; + maxIntDigits = std::max(maxIntDigits, maxDigits); + minIntDigits = std::min(minIntDigits, minDigits); } - printf("bitWidth: %d\n", bitWidth); - llvm_unreachable("Unsupported bitWidth"); } + + return std::make_tuple(maxIntDigits, minIntDigits, hasNegative); } -FormatInfo getFormatInfo(void *vec, bool isInt, bool isSigned, int32_t bitWidth, - int64_t numElem, bool isHex) { +template +FormatInfo getFormatInfo(const MemRefDescriptor &desc, bool isInt, + bool isSigned, int32_t bitWidth, bool isHex) { if (isHex) { assert(bitWidth >= 8 && bitWidth <= 64 && bitWidth % 8 == 0); return {isInt, isSigned, bitWidth, bitWidth / 4, false, false, true}; } - // Compute the max/min widths for pretty printing. - int maxIntDigits = 0; - int minIntDigits = std::numeric_limits::max(); - bool hasNegative = false; - for (int64_t i = 0; i < numElem; ++i) { - auto [digits, negative] = - computeDigitInfo(vec, isInt, isSigned, bitWidth, i); - hasNegative |= negative; - maxIntDigits = std::max(maxIntDigits, digits); - minIntDigits = std::min(minIntDigits, digits); - } + auto [maxIntDigits, minIntDigits, hasNegative] = computeDigitStats(desc); // Fallback to the scientific format for certain cases. 
bool scientific; if (isInt) { @@ -111,178 +140,98 @@ FormatInfo getFormatInfo(void *vec, bool isInt, bool isSigned, int32_t bitWidth, } template -void printElementHelper(std::stringstream &ss, const void *array, - size_t index) { - ss << static_cast(array)[index]; -} - -void printElement(std::stringstream &ss, const void *vec, size_t index, - const FormatInfo &formatInfo) { - if (!formatInfo.isInt) { - switch (formatInfo.bitWidth) { - case 32: - printElementHelper(ss, vec, index); - return; - case 64: - printElementHelper(ss, vec, index); - return; - default: - llvm_unreachable("Unsupported bitWidth"); - } - } - - if (formatInfo.isSigned) { - switch (formatInfo.bitWidth) { - case 64: - printElementHelper(ss, vec, index); - return; - case 32: - printElementHelper(ss, vec, index); - return; - case 16: - printElementHelper(ss, vec, index); - return; - case 8: - // int8_t is printed as char. - ss << static_cast(static_cast(vec)[index]); - return; - case 1: - printElementHelper(ss, vec, index); - return; - default: - llvm_unreachable("Unsupported bitWidth"); - } - } - - switch (formatInfo.bitWidth) { - case 64: - printElementHelper(ss, vec, index); - return; - case 32: - printElementHelper(ss, vec, index); - return; - case 16: - printElementHelper(ss, vec, index); - return; - case 8: - ss << static_cast(static_cast(vec)[index]); - return; - case 1: - printElementHelper(ss, vec, index); - return; - default: - llvm_unreachable("Unsupported bitWidth"); - } -} - -void printFormattedElement(std::stringstream &ss, void *vec, size_t index, +void printFormattedElement(std::stringstream &ss, T val, const FormatInfo &formatInfo) { // Right now, the GPU's hex float doesn't work correctly. C++ has std:: // hexfloat, but let's consider only hex integers for now. if (formatInfo.isHex && formatInfo.isInt) { ss << "0x" << std::hex << std::setw(formatInfo.maxIntDigits) - << std::setfill('0'); - printElement(ss, vec, index, formatInfo); + << std::setfill('0') << val; return; } int padding = 0; - auto [digits, negative] = computeDigitInfo( - vec, formatInfo.isInt, formatInfo.isSigned, formatInfo.bitWidth, index); + auto [digits, negative] = computeDigitInfo(val); if (!negative && formatInfo.hasNegative) padding++; if (formatInfo.scientific) { ss << std::scientific << std::setw(MAX_FLOAT_WIDTH) - << std::setprecision(FLOAT_PREC) << std::string(padding, ' '); - printElement(ss, vec, index, formatInfo); + << std::setprecision(FLOAT_PREC) << std::string(padding, ' ') << val; } else { padding += formatInfo.maxIntDigits - digits; ss << std::fixed << std::setprecision(FLOAT_PREC) - << std::string(padding, ' '); - printElement(ss, vec, index, formatInfo); + << std::string(padding, ' ') << val; } } -template struct RawMemRefDescriptor { - T *allocated; - T *aligned; - intptr_t offset; - intptr_t sizesAndStrides[]; -}; - -template struct MemRefDescriptor { - T *allocated; - T *aligned; - intptr_t offset; - std::vector sizes; - std::vector strides; - int32_t rank; - - MemRefDescriptor(int32_t rank, void *rawDescriptor) : rank(rank) { - auto *rawDesc = static_cast *>(rawDescriptor); - allocated = rawDesc->allocated; - aligned = rawDesc->aligned; - offset = rawDesc->offset; - sizes.resize(rank); - strides.resize(rank); - for (int32_t i = 0; i < rank; i++) { - sizes[i] = rawDesc->sizesAndStrides[i]; - strides[i] = rawDesc->sizesAndStrides[i + rank]; - } - } -}; +// int8_t is printed as char, so use int16_t instead. 
+template <> +void printFormattedElement(std::stringstream &ss, int8_t val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val, formatInfo); +} -struct UnrankedMemRefType { - int64_t rank; - void *descriptor; -}; +template <> +void printFormattedElement(std::stringstream &ss, uint8_t val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val, formatInfo); +} template -void printToStream(MemRefDescriptor &&desc, std::stringstream &ss, - FormatInfo &partialFormatInfo) { - - if (desc.rank > 1) { - ss << "<>\n"; +void printToStreamRecursive(const MemRefDescriptor &desc, + std::stringstream &ss, const FormatInfo &formatInfo, + const std::string &linePrefix) { + if (desc.rank() > 1) { + ss << "["; + for (int64_t i = 0; i < desc.size(0); ++i) { + printToStreamRecursive(desc.subView(i), ss, formatInfo, linePrefix + " "); + if (i != desc.size(0) - 1) + ss << ",\n" << linePrefix << " "; + } + ss << "]"; return; } - if (desc.sizes.size() == 0) { - ss << "<>\n"; - } - - T *vec = desc.aligned; - int32_t numElems = desc.sizes[0]; - FormatInfo formatInfo = getFormatInfo( - vec, partialFormatInfo.isInt, partialFormatInfo.isSigned, - partialFormatInfo.bitWidth, numElems, partialFormatInfo.isHex); - - const size_t header = ss.str().size(); + const T *data = desc.data(); + int64_t stride = desc.stride(0); + int64_t numElems = desc.size(0); + ss << "["; if (numElems <= ELEMS_PER_LINE) { for (int i = 0; i < numElems; i++) { - printFormattedElement(ss, vec, i, formatInfo); + printFormattedElement(ss, data[i * stride], formatInfo); if (i != numElems - 1) ss << ", "; } } else { // TODO: Too many lines? Omit the middle lines. for (int i = 0; i < numElems; i++) { - printFormattedElement(ss, vec, i, formatInfo); + printFormattedElement(ss, data[i * stride], formatInfo); if (i == numElems - 1) break; if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { - ss << ",\n" << std::string(header, ' '); + ss << ",\n" << linePrefix << " "; } else { ss << ", "; } } } - ss << "]\n"; + ss << "]"; +} + +template +void printToStream(const MemRefDescriptor &desc, std::stringstream &ss, + const FormatInfo &partialFormatInfo, + const std::string &linePrefix) { + FormatInfo formatInfo = getFormatInfo( + desc, partialFormatInfo.isInt, partialFormatInfo.isSigned, + partialFormatInfo.bitWidth, partialFormatInfo.isHex); + printToStreamRecursive(desc, ss, formatInfo, linePrefix); } void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, - int32_t btw, bool isInteger, bool isSignedInteger, - bool asHex) { + int32_t btw, bool isInteger, bool isSignedInteger, bool asHex, + const std::string &linePrefix) { FormatInfo partialFormat{.isInt = isInteger, .isSigned = isSignedInteger, @@ -292,11 +241,11 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, switch (btw) { case 64: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 32: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; default: llvm_unreachable("Unsupported bitWidth"); @@ -306,23 +255,23 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, switch (btw) { case 64: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 32: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 16: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, 
linePrefix); return; case 8: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 1: - printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat, + linePrefix); return; default: llvm_unreachable("Unsupported bitWidth"); @@ -331,22 +280,23 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, switch (btw) { case 64: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 32: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 16: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 8: printToStream(MemRefDescriptor(rank, descriptor), ss, - partialFormat); + partialFormat, linePrefix); return; case 1: - printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat); + printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat, + linePrefix); return; default: llvm_unreachable("Unsupported bitWidth"); @@ -367,48 +317,11 @@ EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond, abort(); } -// Print the pid prefix like the GPU ad interpreter. And vectors are printed +// Print the pid prefix like the GPU and interpreter. And vectors are printed // similar to Torch's printing like the following: // (1, 0, 0) x: [ -0.4963, -1.7682, 2.0885, 3.1320, -4.3074, 5.6341, // -6.4901, 7.8964, -8.4556, -9.6323, -10.3489, -11.4017, // -12.0223, 13.1689, 14.2939, -15.5185] -// -// TODO: Implement for higher dimension vectors. -EXPORT void triton_vector_print(int32_t pid0, int32_t pid1, int32_t pid2, - const char *prefix, void *vec, bool isInt, - bool isSigned, int32_t bitWidth, - int64_t numElem, bool isHex) { - - FormatInfo formatInfo = - getFormatInfo(vec, isInt != 0, isSigned != 0, bitWidth, numElem, isHex); - - std::stringstream ss; - ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix << "["; - const size_t header = ss.str().size(); - - if (numElem <= ELEMS_PER_LINE) { - for (int i = 0; i < numElem; i++) { - printFormattedElement(ss, vec, i, formatInfo); - if (i != numElem - 1) - ss << ", "; - } - } else { - // TODO: Too many lines? Omit the middle lines. 
- for (int i = 0; i < numElem; i++) { - printFormattedElement(ss, vec, i, formatInfo); - if (i == numElem - 1) - break; - if (i % ELEMS_PER_LINE == ELEMS_PER_LINE - 1) { - ss << ",\n" << std::string(header, ' '); - } else { - ss << ", "; - } - } - } - ss << "]\n"; - std::cout << ss.str() << std::flush; -} - EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1, int32_t pid2, const char *prefix, UnrankedMemRefType memref, int32_t btw, @@ -416,9 +329,11 @@ EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1, bool asHex) { std::stringstream ss; ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix; + std::string linePrefix(ss.str().size(), ' '); printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger, - isSignedInteger, asHex); + isSignedInteger, asHex, linePrefix); + ss << "\n"; std::cout << ss.str() << std::flush; } From fbdcbfc3ce834f1de5af29b0c34d06814aee04a5 Mon Sep 17 00:00:00 2001 From: Bin Bao Date: Mon, 11 Nov 2024 16:42:25 -0500 Subject: [PATCH 138/165] Fix linux-aarch64 build (#176) Summary: Follow the https://github.com/triton-lang/triton-cpu/pull/165 example, and update one macro for building on linux-aarch64. --- third_party/cpu/triton_cpu.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 9db53e1d0c71..3c3555d3c983 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -20,9 +20,9 @@ #include #include -#ifdef __linux__ +#if defined(__x86_64__) || defined(__i386__) #include -#endif // __linux__ +#endif #include #include From 8f5b2455d53532d63c01266df57e8611b0b77ee7 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 27 Nov 2024 10:13:24 -0600 Subject: [PATCH 139/165] Fix math tests for armv8 (#178) We only use Sleef on !x86 platforms. Sleef APIs are not fully agnostic of the underlying architecture. For example, `Sleef_sinf8_u10` does not exist on Arm. This PR, makes the `MathToVecLibPass` aware of the CPU SIMD architecture by accepting `cpu_features` as new optional argument. No change is expected on x86 side. --- python/src/llvm.cc | 18 +++ python/test/unit/cpu/test_math.py | 23 +++- third_party/cpu/backend/compiler.py | 5 +- .../cpu/include/TritonCPUToLLVM/Passes.h | 3 +- .../cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 115 ++++++++++++------ third_party/cpu/triton_cpu.cc | 5 +- 6 files changed, 126 insertions(+), 43 deletions(-) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 95dfc74d82e5..5977da3e9221 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -595,6 +595,24 @@ void init_triton_llvm(py::module &&m) { if (f.second) res.insert(f.first().str()); } + + // Likely something went wrong with the LLVM feature detection. + if (!res.size()) { + std::string triple = llvm::sys::getProcessTriple(); + // e.g. arm64-apple-darwin24.1.0 + // ^^^^^ + std::size_t pos = triple.find('-'); + if (pos == std::string::npos) { + return res; + } + + std::string arch = triple.substr(0, pos); + if (arch == "aarch64" || arch == "arm64") { + // Safe because NEON is a mandatory feature for aarch64. 
+ res.insert("neon"); // For math tests + } + } + return res; }); } diff --git a/python/test/unit/cpu/test_math.py b/python/test/unit/cpu/test_math.py index 958913e7f9f1..1fd443db967a 100644 --- a/python/test/unit/cpu/test_math.py +++ b/python/test/unit/cpu/test_math.py @@ -5,10 +5,23 @@ import triton import triton.language as tl +from triton._C.libtriton import llvm from triton.language.extra import libdevice from itertools import chain, product +def get_native_vector_size_in_bits(): + """ + Returns the native vector size of the CPU. + Assuming x86 always uses "auto dispatch" with 512-bit vectors for Sleef. + """ + cpu_features = llvm.get_cpu_features() + # TODO support for arm sve w/ VLA + if "neon" in cpu_features: + return 128 + return 512 + + def is_interpreter(): return os.environ.get('TRITON_INTERPRET', '0') == '1' @@ -34,9 +47,13 @@ def check_num_vec_calls(meta, vec_lib, dtype_str, size, is_always_extern=False): # FP16 and BF16 are cast to FP32 for math ops elem_size = 8 if dtype_str == "float64" else 4 data_size = size * elem_size - if data_size > 64: - num_vec_calls = data_size // 64 - elif data_size >= 16: + + vec_size = get_native_vector_size_in_bits() / 8 # bytes + # 128-bit vector is the smallest supported by Sleef for both x86 and arm + smallest_vec_size = 128 / 8 # bytes + if data_size > vec_size: + num_vec_calls = data_size // vec_size + elif data_size >= smallest_vec_size: num_vec_calls = 1 else: num_vec_calls = 1 if is_always_extern else 0 diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 5aabcc051b91..c4b5e6ecd918 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -162,7 +162,8 @@ def make_tttcir(self, mod, metadata, opt): pm.enable_debug() cpu.passes.ttcpuir.add_optimize_masks(pm) passes.common.add_canonicalizer(pm) - convert_bf16_dot_product = self.cpu_arch == "aarch64" and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features + convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8") + and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features) if convert_bf16_dot_product: use_horizontal_sum = os.getenv("TRITON_CPU_DOT_PROD_HORIZ_SUM", "1") == "1" cpu.passes.ttcpuir.add_convert_dot_product(pm, use_horizontal_sum) @@ -215,7 +216,7 @@ def make_llir(self, src, metadata, options): VecLib.libmvec: {"avx512f"}, } if (vec_lib := options.get_vec_lib()) and vec_lib_requirements[vec_lib] & self.cpu_features: - cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib) + cpu.passes.ttcpuir.add_math_to_vec_lib(pm, vec_lib, self.cpu_features) passes.convert.add_math_to_llvmir(pm) cpu.passes.ttcpuir.add_math_to_libm(pm) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index 6e9892d00206..cc29821c580c 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -32,7 +32,8 @@ std::unique_ptr> createLowerMultiReductionPass(); std::unique_ptr> createAtomicOpsToLLVMPass(); std::unique_ptr> createDebugOpsToLLVMPass(); std::unique_ptr> -createMathToVecLibPass(VecLib lib = VecLib::Sleef); +createMathToVecLibPass(VecLib lib = VecLib::Sleef, + std::set cpu_features = {}); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUToLLVM/Passes.h.inc" diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index b68d5a7473d0..2b1877c1c17b 100644 --- 
a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -56,14 +56,17 @@ template struct VecOpToFp32 : public OpRewritePattern { }; // Decompose vector operation to single-dimensional vector operations -// with a native AVX512 vector size. +// with a AVX512 for x86 or NEON for ARM. template struct DecomposeToNativeVecs : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; + // CPU SIMD vector size in bits + size_t vec_bits; - DecomposeToNativeVecs(MLIRContext *context) - : OpRewritePattern(context) {} + DecomposeToNativeVecs(MLIRContext *context, + size_t native_vec_size_in_bits = 512) + : OpRewritePattern(context), vec_bits(native_vec_size_in_bits) {} LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); @@ -83,7 +86,7 @@ struct DecomposeToNativeVecs : public OpRewritePattern { // vector size. auto shape = vecTy.getShape(); SmallVector newShape(1, 1); - int64_t elemsPerVec = 512 / elemTy.getIntOrFloatBitWidth(); + int64_t elemsPerVec = vec_bits / elemTy.getIntOrFloatBitWidth(); for (int64_t i = shape.size() - 1; i >= 0; --i) { int64_t size = shape[i]; if (newShape.size() > 1) { @@ -330,9 +333,11 @@ struct ExternElementwiseOpConversion template void populatePatternsForOp(RewritePatternSet &patterns, - GetVecFnNameFn getVecFnName) { + GetVecFnNameFn getVecFnName, + size_t vec_size_in_bits = 512) { patterns.add>(patterns.getContext()); - patterns.add>(patterns.getContext()); + patterns.add>(patterns.getContext(), + vec_size_in_bits); patterns.add>(patterns.getContext(), getVecFnName); } @@ -340,8 +345,27 @@ void populatePatternsForOp(RewritePatternSet &patterns, struct MathToVecLibPass : public mlir::triton::cpu::impl::MathToVecLibBase { MathToVecLibPass() = default; + size_t vec_size_in_bits; - explicit MathToVecLibPass(VecLib lib) { this->lib = lib; } + explicit MathToVecLibPass(VecLib lib, std::set cpu_features) { + this->lib = lib; + update_vec_size(cpu_features); + } + + void update_vec_size(std::set &cpu_features) { + // TODO: + // Refactor this as an independent function. + // And improve this to support other x86 SIMD ISAs and also for arm SVE + // (VLA) + vec_size_in_bits = 512; + for (auto feature : cpu_features) { + // Arm NEON is fixed 128-bit SIMD ISA. 
+ if (feature == "neon") { + vec_size_in_bits = 128; + break; + } + } + } void runOnOperation() override { Operation *op = getOperation(); @@ -356,20 +380,20 @@ struct MathToVecLibPass } case VecLib::Sleef: { populateCommonPatterns(patterns); - populatePatternsForOp(patterns, - SleefNameGenerator("expm1")); + populatePatternsForOp( + patterns, SleefNameGenerator("expm1"), vec_size_in_bits); populatePatternsForOp( - patterns, SleefNameGenerator("floor", /*ulp=*/0)); + patterns, SleefNameGenerator("floor", /*ulp=*/0), vec_size_in_bits); populatePatternsForOp( - patterns, SleefNameGenerator("sqrt", /*ulp=*/5)); + patterns, SleefNameGenerator("sqrt", /*ulp=*/5), vec_size_in_bits); populatePatternsForOp( - patterns, SleefNameGenerator("trunc", /*ulp=*/0)); + patterns, SleefNameGenerator("trunc", /*ulp=*/0), vec_size_in_bits); break; } } patterns.add>( - patterns.getContext()); + patterns.getContext(), vec_size_in_bits); patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); @@ -379,26 +403,46 @@ struct MathToVecLibPass template void populateCommonPatterns(RewritePatternSet &patterns) const { - populatePatternsForOp(patterns, VecFnNameGenerator("acos")); - populatePatternsForOp(patterns, VecFnNameGenerator("acosh")); - populatePatternsForOp(patterns, VecFnNameGenerator("asin")); - populatePatternsForOp(patterns, VecFnNameGenerator("asinh")); - populatePatternsForOp(patterns, VecFnNameGenerator("atan")); - populatePatternsForOp(patterns, VecFnNameGenerator("atanh")); - populatePatternsForOp(patterns, VecFnNameGenerator("cbrt")); - populatePatternsForOp(patterns, VecFnNameGenerator("cos")); - populatePatternsForOp(patterns, VecFnNameGenerator("cosh")); - populatePatternsForOp(patterns, VecFnNameGenerator("erf")); - populatePatternsForOp(patterns, VecFnNameGenerator("exp")); - populatePatternsForOp(patterns, VecFnNameGenerator("exp2")); - populatePatternsForOp(patterns, VecFnNameGenerator("log")); - populatePatternsForOp(patterns, VecFnNameGenerator("log2")); - populatePatternsForOp(patterns, VecFnNameGenerator("log10")); - populatePatternsForOp(patterns, VecFnNameGenerator("log1p")); - populatePatternsForOp(patterns, VecFnNameGenerator("sin")); - populatePatternsForOp(patterns, VecFnNameGenerator("sinh")); - populatePatternsForOp(patterns, VecFnNameGenerator("tan")); - populatePatternsForOp(patterns, VecFnNameGenerator("tanh")); + populatePatternsForOp(patterns, VecFnNameGenerator("acos"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("acosh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("asin"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("asinh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("atan"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("atanh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("cbrt"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("cos"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("cosh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("erf"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("exp"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("exp2"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("log"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("log2"), + vec_size_in_bits); + 
populatePatternsForOp(patterns, VecFnNameGenerator("log10"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("log1p"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("sin"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("sinh"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("tan"), + vec_size_in_bits); + populatePatternsForOp(patterns, VecFnNameGenerator("tanh"), + vec_size_in_bits); } }; @@ -408,8 +452,9 @@ namespace mlir { namespace triton { namespace cpu { -std::unique_ptr> createMathToVecLibPass(VecLib lib) { - return std::make_unique(lib); +std::unique_ptr> +createMathToVecLibPass(VecLib lib, std::set cpu_features) { + return std::make_unique(lib, cpu_features); } } // namespace cpu diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 3c3555d3c983..a412190bbcf8 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -145,8 +145,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_memref_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::createFinalizeMemRefToLLVMConversionPass()); }); - m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib) { - pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib)); + m.def("add_math_to_vec_lib", [](mlir::PassManager &pm, cpu::VecLib lib, + std::set cpu_features) { + pm.addPass(mlir::triton::cpu::createMathToVecLibPass(lib, cpu_features)); }); m.def("add_math_to_libm", [](mlir::PassManager &pm) { pm.addPass(mlir::createConvertMathToLibmPass()); From c0cbf97e0874bb1a7ef33c4a9b39d92aa05a5cd9 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 4 Dec 2024 12:03:45 -0600 Subject: [PATCH 140/165] Allow using local omp with Apple clang (#181) Tested on my M1 Mac as, ``` OMP_NUM_THREADS=8 \ TRITON_LOCAL_LIBOMP_PATH="/site-packages/torch/" \ CC=$(which clang) \ TRITON_CPU_BACKEND=1 \ $(which python3) \ python/tutorials/02-fused-softmax-cpu.py ``` --- python/triton/runtime/build.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 67f3a82a21fd..58a52fb8f82f 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -18,6 +18,16 @@ def quiet(): finally: sys.stdout, sys.stderr = old_stdout, old_stderr + +def _is_apple_clang(): + if platform.system() != "Darwin": + return False + res = subprocess.run(["clang", "--version"], capture_output=True, text=True) + if res.returncode != 0: + return False + return "Apple clang" in res.stdout + + def _build(name, src, srcdir, library_dirs, include_dirs, libraries): suffix = sysconfig.get_config_var('EXT_SUFFIX') system = platform.system() @@ -63,7 +73,20 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): if src.endswith(".cpp") or src.endswith(".cc"): cc_cmd += ["-std=c++17"] if not os.environ.get("TRITON_DISABLE_OPENMP", None): - cc_cmd += ["-fopenmp"] + libomp_path = os.environ.get("TRITON_LOCAL_LIBOMP_PATH", None) + if _is_apple_clang(): + if libomp_path: + cc_cmd += ["-Xclang"] + cc_cmd += ["-fopenmp"] + cc_cmd += [f"-I{libomp_path}/include"] + cc_cmd += [f"-L{libomp_path}/lib"] + cc_cmd += ["-lomp"] + else: + print("Warning: TRITON_LOCAL_LIBOMP_PATH is not set for Apple clang. 
OpenMP is disabled.") + else: + cc_cmd += ["-fopenmp"] + if libomp_path: + print("Info: Ignoring TRITON_LOCAL_LIBOMP_PATH for non-Apple clang compiler") if src.endswith(".s"): # This is required to properly parse .file directives cc_cmd += ["-g"] From 217591b653ae5bf539a06cc9e5db23e2b887c2e9 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 6 Dec 2024 00:43:09 -0500 Subject: [PATCH 141/165] Add pytest.mark.cpu to two more already-passing tests (#183) --- python/test/unit/language/test_core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3caf302fc83d..e91e6d90dc82 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -4632,6 +4632,7 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): assert "ld.global.v4.b32" not in ptx +@pytest.mark.cpu @pytest.mark.interpreter def test_assume(device): @@ -5666,6 +5667,7 @@ def nested_while(data, countPtr): assert data[0] == 40 +@pytest.mark.cpu def test_constexpr_if_return(device): # Reproducer for #4883, return statement in an if with a constexpr causes # errors when combined with non-trivial control flow graphs From f1a54c4377ca1e212b0272821dc19af85e14e6b2 Mon Sep 17 00:00:00 2001 From: Jez Ng Date: Fri, 6 Dec 2024 00:43:32 -0500 Subject: [PATCH 142/165] Move libdevice to third_party (#182) Closes #171. --- .../language/extra => third_party/cpu/language}/cpu/__init__.py | 0 .../language/extra => third_party/cpu/language}/cpu/libdevice.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {python/triton/language/extra => third_party/cpu/language}/cpu/__init__.py (100%) rename {python/triton/language/extra => third_party/cpu/language}/cpu/libdevice.py (100%) diff --git a/python/triton/language/extra/cpu/__init__.py b/third_party/cpu/language/cpu/__init__.py similarity index 100% rename from python/triton/language/extra/cpu/__init__.py rename to third_party/cpu/language/cpu/__init__.py diff --git a/python/triton/language/extra/cpu/libdevice.py b/third_party/cpu/language/cpu/libdevice.py similarity index 100% rename from python/triton/language/extra/cpu/libdevice.py rename to third_party/cpu/language/cpu/libdevice.py From 682cc03f085d587aedc6f17f65cea83d90d0a24c Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 22 Nov 2024 19:52:21 +0100 Subject: [PATCH 143/165] Introduce triton_cpu.DotOp. 
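
With this change, `tt.dot` on the CPU backend is first lowered to a vector-level `triton_cpu.dot` op, and dedicated passes then rewrite it either into AMX tile operations or, as a fallback, into a generic `vector.contract` (see ConvertDotToAMX and the new ConvertDotGeneric pass below). For reference, a small Triton matmul kernel that exercises this path might look like the sketch below; the kernel and wrapper are illustrative only and assume M, N, and K are multiples of the block sizes, with no masking.

import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn,
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + rm[:, None] * stride_am + (k + rk)[None, :] * stride_ak)
        b = tl.load(b_ptr + (k + rk)[:, None] * stride_bk + rn[None, :] * stride_bn)
        # On the CPU backend, tl.dot is converted to the new triton_cpu.dot op.
        acc = tl.dot(a, b, acc)
    tl.store(c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn, acc)

def matmul(a, b, BLOCK_M=32, BLOCK_N=32, BLOCK_K=32):
    M, K = a.shape
    K2, N = b.shape
    assert K == K2 and M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0
    c = torch.empty((M, N), dtype=torch.float32)
    grid = (M // BLOCK_M, N // BLOCK_N)
    matmul_kernel[grid](a, b, c, M, N, K,
                        a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                        c.stride(0), c.stride(1),
                        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K)
    return c

The generic fallback maps the 2-D case onto a vector.contract with two parallel iterators and one reduction iterator, which keeps all AMX-specific pattern matching isolated in its own pass.
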
Signed-off-by: Ilya Enkovich --- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 24 ++++ lib/Dialect/TritonCPU/IR/Ops.cpp | 11 ++ test/TritonCPU/dot-to-amx.mlir | 8 +- third_party/cpu/backend/compiler.py | 1 + .../cpu/include/TritonCPUTransforms/Passes.h | 1 + .../cpu/include/TritonCPUTransforms/Passes.td | 12 ++ .../lib/TritonCPUTransforms/CMakeLists.txt | 3 +- .../ConvertDotOp/ConvertDotGeneric.cpp | 132 ++++++++++++++++++ .../{ => ConvertDotOp}/ConvertDotToAMX.cpp | 106 +++++--------- .../lib/TritonToTritonCPU/ConvertDotOp.cpp | 35 +---- third_party/cpu/triton_cpu.cc | 3 + 11 files changed, 229 insertions(+), 107 deletions(-) create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotGeneric.cpp rename third_party/cpu/lib/TritonCPUTransforms/{ => ConvertDotOp}/ConvertDotToAMX.cpp (91%) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index 28eefd383f9c..b58fd9320354 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -166,5 +166,29 @@ def TTC_AssertOp : TTC_Op<"assert", [MemoryEffects<[MemWrite]>]> { let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)"; } +def TTC_DotOp : TTC_Op<"dot", [Pure, + DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", + "d", "c", "$_self">]> { + let summary = "dot"; + + let description = [{Same as tt.dot but on vectors.}]; + + let arguments = ( + ins + TTC_Vector:$a, + TTC_Vector:$b, + TTC_Vector:$c, + DefaultValuedAttr:$inputPrecision, + DefaultValuedAttr:$maxNumImpreciseAcc + ); + + let results = (outs TTC_Vector:$d); + + let assemblyFormat = [{ + $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? 
attr-dict `:` + type($a) `*` type($b) `->` type($d) + }]; +} #endif diff --git a/lib/Dialect/TritonCPU/IR/Ops.cpp b/lib/Dialect/TritonCPU/IR/Ops.cpp index 358ab418ceba..b8523ebcd8ac 100644 --- a/lib/Dialect/TritonCPU/IR/Ops.cpp +++ b/lib/Dialect/TritonCPU/IR/Ops.cpp @@ -26,4 +26,15 @@ void ExternElementwiseOp::getEffects( SideEffects::DefaultResource::get()); } +LogicalResult +DotOp::inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + // type is the same as the accumulator + auto accTy = cast(operands[2].getType()); + inferredReturnTypes.push_back(accTy); + return success(); +} + } // namespace mlir::triton::cpu diff --git a/test/TritonCPU/dot-to-amx.mlir b/test/TritonCPU/dot-to-amx.mlir index 19c476403d2e..b288b10bf62b 100644 --- a/test/TritonCPU/dot-to-amx.mlir +++ b/test/TritonCPU/dot-to-amx.mlir @@ -33,7 +33,7 @@ module { %6 = triton_cpu.extract_memref %1 : > -> memref<32x16xbf16, strided<[16, 1]>> loc(#loc) %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<32x16xbf16, strided<[16, 1]>>, vector<32x16xbf16> loc(#loc) - %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %5, %8, %cst_0 : vector<16x32xbf16>, vector<32x16xbf16> into vector<16x16xf32> loc(#loc) + %9 = triton_cpu.dot %5, %8, %cst_0, inputPrecision = ieee : vector<16x32xbf16> * vector<32x16xbf16> -> vector<16x16xf32> loc(#loc) %10 = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> loc(#loc) %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32, strided<[16, 1]>> loc(#loc) @@ -80,7 +80,7 @@ module { %6 = triton_cpu.extract_memref %1 : > -> memref<128x16xi8, strided<[16, 1]>> loc(#loc) %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) %8 = vector.transfer_read %6[%7#0, %7#1], %c0_i8 {in_bounds = [true, true]} : memref<128x16xi8, strided<[16, 1]>>, vector<128x16xi8> loc(#loc) - %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %5, %8, %cst : vector<16x128xi8>, vector<128x16xi8> into vector<16x16xi32> loc(#loc) + %9 = triton_cpu.dot %5, %8, %cst, inputPrecision = ieee : vector<16x128xi8> * vector<128x16xi8> -> vector<16x16xi32> loc(#loc) %10 = triton_cpu.extract_memref %2 : > -> memref<16x16xi32, strided<[16, 1]>> loc(#loc) %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x16xi32>, memref<16x16xi32, strided<[16, 1]>> loc(#loc) @@ -136,7 +136,7 @@ module { %6 = triton_cpu.extract_memref %1 : > -> memref<64x32xbf16, strided<[32, 1]>> loc(#loc) %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x32xbf16, strided<[32, 1]>>, vector<64x32xbf16> loc(#loc) - %9 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %5, %8, %cst_0 : vector<16x64xbf16>, vector<64x32xbf16> into vector<16x32xf32> loc(#loc) + %9 = triton_cpu.dot %5, %8, %cst_0, inputPrecision = ieee : vector<16x64xbf16> 
* vector<64x32xbf16> -> vector<16x32xf32> loc(#loc) %10 = triton_cpu.extract_memref %2 : > -> memref<16x32xf32, strided<[32, 1]>> loc(#loc) %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x32xf32>, memref<16x32xf32, strided<[32, 1]>> loc(#loc) @@ -237,7 +237,7 @@ module { %9 = triton_cpu.extract_memref %arg6 : > -> memref<128x32xf8E5M2, strided<[32, 1]>> loc(#loc) %10:2 = triton_cpu.extract_indices %arg6 : > -> index, index loc(#loc) %11 = vector.transfer_read %9[%10#0, %10#1], %cst {in_bounds = [true, true]} : memref<128x32xf8E5M2, strided<[32, 1]>>, vector<64x32xf8E5M2> loc(#loc) - %12 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %8, %11, %arg4 : vector<64x64xf8E5M2>, vector<64x32xf8E5M2> into vector<64x32xf32> loc(#loc) + %12 = triton_cpu.dot %8, %11, %arg4, inputPrecision = ieee : vector<64x64xf8E5M2> * vector<64x32xf8E5M2> -> vector<64x32xf32> loc(#loc) %13 = tt.advance %arg5, [%c0_i32, %c64_i32] : > loc(#loc) %14 = tt.advance %arg6, [%c64_i32, %c0_i32] : > loc(#loc) scf.yield %12, %13, %14 : vector<64x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index c4b5e6ecd918..5395afe282f8 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -174,6 +174,7 @@ def make_tttcir(self, mod, metadata, opt): amx_fp16 = False amx_bf16 = 'amx-bf16' in self.cpu_features cpu.passes.ttcpuir.add_convert_dot_to_amx(pm, amx_int8, amx_fp16, amx_bf16) + cpu.passes.ttcpuir.add_convert_dot_generic(pm) promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features # We don't have any lowering for mixed precision matmuls, so always use casts for now convert_mixed_precision_matmul = True diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index d9c121cb219b..9b7402d7f0f8 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -36,6 +36,7 @@ createConvertDotProduct(bool useHorizontalSum); std::unique_ptr> createConvertDotToAMX(); std::unique_ptr> createConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16); +std::unique_ptr> createConvertDotGeneric(); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index c337d2c92eec..a38673595538 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -125,4 +125,16 @@ def ConvertDotToAMX : Pass<"triton-cpu-convert-dot-to-amx", "mlir::ModuleOp"> { "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertDotGeneric : Pass<"triton-cpu-convert-dot-generic", "mlir::ModuleOp"> { + let summary = "Generic convertion of dot product op."; + let description = [{ + This pass is used to lower matmul operations to generic vector code. 
+ }]; + + let constructor = "mlir::triton::cpu::createConvertDotGeneric()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt index 3bf2e3568238..c421e35f8797 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonCPUTransforms + ConvertDotOp/ConvertDotGeneric.cpp + ConvertDotOp/ConvertDotToAMX.cpp ConvertDotProduct.cpp - ConvertDotToAMX.cpp ConvertUnsupportedOps.cpp DecomposeFpConversions.cpp OptimizeMasks.cpp diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotGeneric.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotGeneric.cpp new file mode 100644 index 000000000000..9465e67b36cf --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotGeneric.cpp @@ -0,0 +1,132 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTGENERIC +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class DotConversionTarget : public ConversionTarget { +public: + explicit DotConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addIllegalOp(); + } +}; + +struct DotOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(cpu::DotOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + Value a = op.getA(); + Value b = op.getB(); + Value c = op.getC(); + VectorType aType = cast(a.getType()); + VectorType bType = cast(b.getType()); + VectorType cType = cast(c.getType()); + + uint32_t rank = aType.getRank(); + if (rank == 2) { + auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, + vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } else if (rank == 3) { + auto aMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx); + auto bMap = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx); + auto cMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx); + auto iteratorTypes = rewriter.getArrayAttr( + {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + 
vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), + vector::IteratorTypeAttr::get(ctx, + vector::IteratorType::reduction)}); + rewriter.replaceOpWithNewOp( + op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), + iteratorTypes); + return success(); + } + + return failure(); + } + + SmallVector deinterleave(Location loc, ArrayRef vals, + ConversionPatternRewriter &rewriter) const { + SmallVector res; + for (auto &val : vals) { + auto op = rewriter.create(loc, val); + res.push_back(op.getResult(0)); + res.push_back(op.getResult(1)); + } + return res; + } +}; + +struct ConvertDotGeneric + : public triton::cpu::impl::ConvertDotGenericBase { + using ConvertDotGenericBase::ConvertDotGenericBase; + + ConvertDotGeneric() : ConvertDotGenericBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + DotConversionTarget convTarget(*context); + RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotGeneric() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp similarity index 91% rename from third_party/cpu/lib/TritonCPUTransforms/ConvertDotToAMX.cpp rename to third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index aacf150de3b1..73fcd627ea36 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -46,7 +46,7 @@ struct AmxBuffer { // Mul[F|I]Op operations. struct AmxDotOpCandidate { // Operation to convert. - vector::ContractionOp op; + cpu::DotOp op; // Available LHS, RHS, and accumulator types are limited in AMX and we might // require additional casts. Here we keep actual element types used by LHS, // RHS, and accumulator in AMX tiles. @@ -80,53 +80,6 @@ struct AmxDotOpCandidate { Operation *origStore = nullptr; }; -bool checkIdxMap(Attribute attr, unsigned int v1, unsigned int v2) { - auto map = cast(attr).getAffineMap(); - return map == - AffineMap::getMultiDimMapWithTargets(3, {v1, v2}, attr.getContext()); -} - -// Return true if specified contraction op is actually a converted DotOp. -bool isDotOp(vector::ContractionOp op) { - // First, check ranks of inputs. - if (cast(op.getLhs().getType()).getRank() != 2 || - cast(op.getRhs().getType()).getRank() != 2 || - cast(op.getAcc().getType()).getRank() != 2) { - LDBG("Drop candidate with rank != 2"); - return false; - } - - // Matmul uses add as a combining function. - if (op.getKind() != vector::CombiningKind::ADD) { - LDBG("Drop candidate with combining function " << op.getKind()); - return false; - } - - // Expect two parallel and one reduction iterators. - auto iterTypes = op.getIteratorTypes(); - if (iterTypes.size() != 3 || - cast(iterTypes[0]).getValue() != - vector::IteratorType::parallel || - cast(iterTypes[1]).getValue() != - vector::IteratorType::parallel || - cast(iterTypes[2]).getValue() != - vector::IteratorType::reduction) { - LDBG("Drop candidate with mismatched iterator types."); - return false; - } - - // Check affine maps. 
- // TODO: be less restrictive on maps to allow transposed inputs? - auto idxMaps = op.getIndexingMaps(); - if (!checkIdxMap(idxMaps[0], 0, 2) || !checkIdxMap(idxMaps[1], 2, 1) || - !checkIdxMap(idxMaps[2], 0, 1)) { - LDBG("Drop candidate with mismatched affine maps."); - return false; - } - - return true; -} - // Check if input and output types can be handled by AMX (possibly, using // additional casts for input/output). Returns true if AMX usage is possible. // In this case, tile element type fields of the candidate structure are @@ -216,6 +169,19 @@ bool checkElemTypes(Type lhsElemTy, Type rhsElemTy, Type accElemTy, return true; } +// Check input shapes. Currently, support only 2D cases and ignore small +// inputs. +bool checkInputShapes(VectorType lhsTy, VectorType resTy) { + if (lhsTy.getRank() != 2) + return false; + + if (lhsTy.getDimSize(0) < 8 || lhsTy.getDimSize(1) < 8 || + resTy.getDimSize(1) < 8) + return false; + + return true; +} + // Check if accumulator value is updated in a loop and has no other // usages than a dot op, that updates it. Tile loads/stores and casts // for such accumulators can be done outside of the loop. @@ -262,7 +228,7 @@ bool isLoopCarriedAcc(Value acc) { // Return a value that holds the resulting loop carried accumulator value. // It's one of ForOp's results. -Value getResValueForLoopCarriedAcc(vector::ContractionOp op) { +Value getResValueForLoopCarriedAcc(cpu::DotOp op) { Value updAcc = op.getResult(); auto forOp = dyn_cast(op->getParentOp()); auto &use = *updAcc.getUses().begin(); @@ -330,22 +296,16 @@ void findOutputBuffer(Value val, AmxDotOpCandidate &candidate) { // Check if specified ContractionOp can be lowered to AMX operations. // If conversion is possible, then true is returned and candidate // structure is filled with detailed transformation info. -bool isAmxCandidate(vector::ContractionOp op, bool supportInt8, - bool supportFp16, bool supportBf16, - AmxDotOpCandidate &candidate) { +bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, + bool supportBf16, AmxDotOpCandidate &candidate) { MLIRContext *ctx = op.getContext(); - VectorType lhsTy = cast(op.getLhs().getType()); - VectorType rhsTy = cast(op.getRhs().getType()); - VectorType accTy = cast(op.getAcc().getType()); + VectorType lhsTy = cast(op.getA().getType()); + VectorType rhsTy = cast(op.getB().getType()); + VectorType accTy = cast(op.getC().getType()); VectorType resTy = cast(op.getType()); LDBG("Considering candidate op: " << op); - // Contraction op is very generic. For now, we generate it only as a - // result of DotOp conversion. But still check it's what we expect. - if (!isDotOp(op)) - return false; - // Check if input and output types match available hardware capabilities. // If check is successful then tile element types are filled with types // to use in AMX operations. @@ -354,9 +314,13 @@ bool isAmxCandidate(vector::ContractionOp op, bool supportInt8, supportInt8, supportFp16, supportBf16, candidate)) return false; + // Check input shapes. 
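+  // Small or non-2D dots are rejected here and left to the generic lowering.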
+ if (!checkInputShapes(lhsTy, resTy)) + return false; + candidate.op = op; setupBlockAndTileSizes(lhsTy.getShape(), rhsTy.getShape(), candidate); - candidate.keepAccOnTiles = isLoopCarriedAcc(op.getAcc()); + candidate.keepAccOnTiles = isLoopCarriedAcc(op.getC()); // Can't keep acc in a tile the whole loop right now: // https://github.com/llvm/llvm-project/issues/109481 @@ -697,12 +661,12 @@ void multiplyBlocksPreloadRhs(Location loc, VectorType lhsTileTy, LogicalResult convertCandidate(AmxDotOpCandidate &candidate, PatternRewriter &rewriter) { - vector::ContractionOp op = candidate.op; + cpu::DotOp op = candidate.op; Location loc = op.getLoc(); - VectorType lhsTy = cast(op.getLhs().getType()); - VectorType rhsTy = cast(op.getRhs().getType()); - VectorType accTy = cast(op.getAcc().getType()); - VectorType resTy = cast(op.getResultType()); + VectorType lhsTy = cast(op.getA().getType()); + VectorType rhsTy = cast(op.getB().getType()); + VectorType accTy = cast(op.getC().getType()); + VectorType resTy = cast(op.getResult().getType()); VectorType lhsTileTy = lhsTy.cloneWith(SmallVector({candidate.tileM, candidate.tileK}), candidate.lhsTileElemTy); @@ -726,15 +690,15 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, // Cast input data if required and prepare input buffer. It might be temporary // buffers with stored vectors or the original input memory. - Value lhs = maybeCast(loc, op.getLhs(), candidate.lhsTileElemTy, rewriter); + Value lhs = maybeCast(loc, op.getA(), candidate.lhsTileElemTy, rewriter); AmxBuffer lhsBuf = prepareTensorBuffer(loc, lhs, false, false, true, allocaPoint, rewriter); - Value rhs = maybeCast(loc, op.getRhs(), candidate.rhsTileElemTy, rewriter); + Value rhs = maybeCast(loc, op.getB(), candidate.rhsTileElemTy, rewriter); AmxBuffer rhsBuf = prepareTensorBuffer(loc, rhs, true, false, true, allocaPoint, rewriter); - Value acc = maybeCast(loc, op.getAcc(), candidate.accTileElemTy, rewriter); + Value acc = maybeCast(loc, op.getC(), candidate.accTileElemTy, rewriter); Value accToStore = acc; scf::ForOp forOp; if (candidate.keepAccInBuf || candidate.keepAccOnTiles) { @@ -818,7 +782,7 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, rewriter.replaceAllUsesWith(loopRes, newVal); // For now, just use init value for unused ForOp result instead of // its removal. 
- rewriter.replaceOp(op, op.getAcc()); + rewriter.replaceOp(op, op.getC()); } else if (candidate.outBuf.empty()) { LDBG("Loading the result to a vector to replace orig op result."); Value newVal = rewriter.create( @@ -852,7 +816,7 @@ struct ConvertDotToAMX ModuleOp mod = getOperation(); SmallVector candidates; - mod->walk([this, &candidates](vector::ContractionOp op) { + mod->walk([this, &candidates](cpu::DotOp op) { AmxDotOpCandidate candidate; if (isAmxCandidate(op, convertInt8, convertFp16, convertBf16, candidate)) { diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp index 06cfb0d834d0..4672d3cca3f0 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDotOp.cpp @@ -60,39 +60,12 @@ struct DotOpConversion : public OpConversionPattern { auto cType = cast(c.getType()); assert(aType.getRank() == bType.getRank() && bType.getRank() == cType.getRank() && + (aType.getRank() == 2 || aType.getRank() == 3) && "Mixed ranks, not 2d or 3d matmul, unknown type of op"); - uint32_t rank = aType.getRank(); - if (rank == 2) { - auto aMap = AffineMap::getMultiDimMapWithTargets(3, {0, 2}, ctx); - auto bMap = AffineMap::getMultiDimMapWithTargets(3, {2, 1}, ctx); - auto cMap = AffineMap::getMultiDimMapWithTargets(3, {0, 1}, ctx); - auto iteratorTypes = rewriter.getArrayAttr( - {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, - vector::IteratorType::reduction)}); - rewriter.replaceOpWithNewOp( - op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), - iteratorTypes); - return success(); - } else if (rank == 3) { - auto aMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx); - auto bMap = AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx); - auto cMap = AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx); - auto iteratorTypes = rewriter.getArrayAttr( - {vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, vector::IteratorType::parallel), - vector::IteratorTypeAttr::get(ctx, - vector::IteratorType::reduction)}); - rewriter.replaceOpWithNewOp( - op, a, b, c, rewriter.getAffineMapArrayAttr({aMap, bMap, cMap}), - iteratorTypes); - return success(); - } - - return failure(); + rewriter.replaceOpWithNewOp(op, a, b, c, op.getInputPrecision(), + op.getMaxNumImpreciseAcc()); + return success(); } }; diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index a412190bbcf8..40f56204c427 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -88,6 +88,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { pm.addPass(mlir::triton::cpu::createConvertDotToAMX( convertInt8, convertFp16, convertBf16)); }); + m.def("add_convert_dot_generic", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDotGeneric()); + }); m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm, bool promote_bf16_to_fp32, bool convert_mixed_precision_matmul, bool promote_lib_math_to_fp32) { From 74a3488252b70f398568b82dd41d1aace97156ae Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 2 Oct 2024 21:25:18 +0000 Subject: [PATCH 144/165] Fixes to use the latest LLVM. 
Signed-off-by: Ilya Enkovich --- third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp index 1b078f20020e..728d353592bb 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/TypeConverter.cpp @@ -16,8 +16,7 @@ TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { }); addArgumentMaterialization([&](OpBuilder &builder, Type type, - ValueRange inputs, - Location loc) -> std::optional { + ValueRange inputs, Location loc) -> Value { if (isa(type)) return builder.create(loc, type, inputs) .getResult(0); @@ -31,14 +30,14 @@ TritonToTritonCPUTypeConverter::TritonToTritonCPUTypeConverter() { // Converted ops produce vectors instead of tensors. Provide conversion // here for users. addSourceMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) -> std::optional { + Location loc) -> Value { return builder.create(loc, type, inputs) .getResult(0); }); // Provide conversion for vector users. addTargetMaterialization([&](OpBuilder &builder, Type type, ValueRange inputs, - Location loc) -> std::optional { + Location loc) -> Value { if (isa(type)) return builder.create(loc, type, inputs) .getResult(0); From ad46864e27db71c0043645cf8dd7fcef111af9e2 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 6 Dec 2024 18:04:26 +0000 Subject: [PATCH 145/165] Fix pybind11 build issue for TritonCPU. Signed-off-by: Ilya Enkovich --- third_party/cpu/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index 59d0f5c53d46..c62a4cda03ad 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -4,7 +4,7 @@ add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) - target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation PRIVATE Python3::Module pybind11::headers) endif() add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) From 8ca15dad0c5890fa73cec6f5bd179d96bc4995d8 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 2 Oct 2024 21:25:56 +0000 Subject: [PATCH 146/165] Use mlir::amx::TileType. 
Signed-off-by: Ilya Enkovich --- test/TritonCPU/dot-to-amx.mlir | 154 ++++++++---------- .../cpu/lib/TritonCPUToLLVM/TypeConverter.cpp | 5 + .../ConvertDotOp/ConvertDotToAMX.cpp | 44 ++--- 3 files changed, 98 insertions(+), 105 deletions(-) diff --git a/test/TritonCPU/dot-to-amx.mlir b/test/TritonCPU/dot-to-amx.mlir index b288b10bf62b..da501849f723 100644 --- a/test/TritonCPU/dot-to-amx.mlir +++ b/test/TritonCPU/dot-to-amx.mlir @@ -6,16 +6,13 @@ // CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<16x32xbf16> // CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x16xf32, strided<[16, 1]>> // CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index -// CHECK: %[[ACC:.+]] = amx.tile_zero : vector<16x16xf32> +// CHECK: %[[ACC:.+]] = amx.tile_zero : !amx.tile<16x16xf32> // CHECK-NEXT: %[[LHS:.+]] = amx.tile_load %3[%4#0, %4#1] // CHECK-NEXT: %[[RHS:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] -// CHECK-NEXT: %[[RES:.+]] = amx.tile_mulf %[[LHS]], %[[RHS]], %[[ACC]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES]] : memref<16x16xf32, strided<[16, 1]>>, vector<16x16xf32> +// CHECK-NEXT: %[[RES:.+]] = amx.tile_mulf %[[LHS]], %[[RHS]], %[[ACC]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES]] : memref<16x16xf32, strided<[16, 1]>>, !amx.tile<16x16xf32> #loc = loc(unknown) -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> module { tt.func public @test_single_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) @@ -49,20 +46,17 @@ module { // CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xi8> // CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x16xi32, strided<[16, 1]>> // CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index -// CHECK: %[[ACC:.+]] = amx.tile_zero : vector<16x16xi32> +// CHECK: %[[ACC:.+]] = amx.tile_zero : !amx.tile<16x16xi32> // CHECK-NEXT: %[[LHS1:.+]] = amx.tile_load %3[%4#0, %4#1] // CHECK-NEXT: %[[RHS1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] -// CHECK-NEXT: %[[RES1:.+]] = amx.tile_muli %[[LHS1]], %[[RHS1]], %[[ACC]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> +// CHECK-NEXT: %[[RES1:.+]] = amx.tile_muli %[[LHS1]], %[[RHS1]], %[[ACC]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> // CHECK-NEXT: %[[IDX1:.+]] = arith.addi %4#1, %c64{{.*}} : index -// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : memref<16x128xi8, strided<[128, 1]>> into vector<16x64xi8> -// CHECK-NEXT: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xi8> into vector<16x64xi8> -// CHECK-NEXT: %[[RES2:.+]] = amx.tile_muli %[[LHS2]], %[[RHS2]], %[[RES1]] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> -// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES2]] : memref<16x16xi32, strided<[16, 1]>>, vector<16x16xi32> +// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : 
memref<16x128xi8, strided<[128, 1]>> into !amx.tile<16x64xi8> +// CHECK-NEXT: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xi8> into !amx.tile<16x64xi8> +// CHECK-NEXT: %[[RES2:.+]] = amx.tile_muli %[[LHS2]], %[[RHS2]], %[[RES1]] : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES2]] : memref<16x16xi32, strided<[16, 1]>>, !amx.tile<16x16xi32> #loc = loc(unknown) -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> module { tt.func public @test_single_tile_two_muli(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { %c0_i8 = arith.constant 0 : i8 loc(#loc) @@ -97,27 +91,24 @@ module { // CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<32x64xbf16> // CHECK: %[[OUT_MEMREF:.+]] = triton_cpu.extract_memref %2 : > -> memref<16x32xf32, strided<[32, 1]>> // CHECK-NEXT: %[[OUT_INDICES:.+]]:2 = triton_cpu.extract_indices %2 : > -> index, index -// CHECK: %[[ACC1:.+]] = amx.tile_zero : vector<16x16xf32> -// CHECK-NEXT: %[[ACC2:.+]] = amx.tile_zero : vector<16x16xf32> -// CHECK-NEXT: %[[LHS1:.+]] = amx.tile_load %3[%4#0, %4#1] : memref<16x64xbf16, strided<[64, 1]>> into vector<16x32xbf16> -// CHECK-NEXT: %[[RHS1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RES1:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS1]], %[[ACC1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RES2:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS2]], %[[ACC2]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK: %[[ACC1:.+]] = amx.tile_zero : !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC2:.+]] = amx.tile_zero : !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS1:.+]] = amx.tile_load %3[%4#0, %4#1] : memref<16x64xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES1:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS1]], %[[ACC1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK: %[[RHS2:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES2:.+]] = amx.tile_mulf %[[LHS1]], %[[RHS2]], %[[ACC2]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> // CHECK: %[[IDX1:.+]] = arith.addi %4#1, %c32{{.*}} : index -// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : memref<16x64xbf16, strided<[64, 1]>> into vector<16x32xbf16> -// CHECK: %[[RHS3:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RES3:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS3]], %[[RES1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES3]] : memref<16x32xf32, strided<[32, 1]>>, vector<16x16xf32> -// CHECK: %[[RHS4:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// 
CHECK-NEXT: %[[RES4:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS4]], %[[RES2]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK-NEXT: %[[LHS2:.+]] = amx.tile_load %3[%4#0, %[[IDX1]]] : memref<16x64xbf16, strided<[64, 1]>> into !amx.tile<16x32xbf16> +// CHECK: %[[RHS3:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES3:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS3]], %[[RES1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[OUT_INDICES]]#1], %[[RES3]] : memref<16x32xf32, strided<[32, 1]>>, !amx.tile<16x16xf32> +// CHECK: %[[RHS4:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES4:.+]] = amx.tile_mulf %[[LHS2]], %[[RHS4]], %[[RES2]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> // CHECK: %[[IDX2:.+]] = arith.addi %[[OUT_INDICES]]#1, %c16{{.*}} : index -// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[IDX2]]], %[[RES4]] : memref<16x32xf32, strided<[32, 1]>>, vector<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[OUT_MEMREF]][%[[OUT_INDICES]]#0, %[[IDX2]]], %[[RES4]] : memref<16x32xf32, strided<[32, 1]>>, !amx.tile<16x16xf32> #loc = loc(unknown) -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> module { tt.func public @test_two_tiles_four_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) @@ -161,60 +152,57 @@ module { // CHECK-NEXT: vector.transfer_write %[[LHS1]], %[[LHS_BUF]][%c0{{.*}}, %c0{{.*}}] {in_bounds = [true, true]} : vector<64x64xbf16>, memref<64x64xbf16> // CHECK-NEXT: %[[RHS1:.+]] = arith.extf %[[RHS]] : vector<64x32xf8E5M2> to vector<64x32xbf16> // CHECK-COUNT-32: vector.store %{{.+}}, %[[RHS_BUF]][%{{.+}}, %{{.+}}] : memref<32x64xbf16>, vector<64xbf16> -// CHECK-NEXT: %[[ACC_0_0:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[ACC_0_1:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[ACC_1_0:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[ACC_1_1:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[LHS_0_0:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[LHS_1_0:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[TMP_0_0:.+]] = amx.tile_mulf %[[LHS_0_0]], %[[RHS_0_0]], %[[ACC_0_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[TMP_1_0:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_0]], %[[ACC_1_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[TMP_0_1:.+]] = amx.tile_mulf 
%[[LHS_0_0]], %[[RHS_0_1]], %[[ACC_0_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[TMP_1_1:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_1]], %[[ACC_1_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[LHS_0_1:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[LHS_1_1:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RES_0_0:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_0]], %[[TMP_0_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}], %[[RES_0_0]] : memref<64x32xf32>, vector<16x16xf32> -// CHECK-NEXT: %[[RES_1_0:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_0]], %[[TMP_1_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}], %[[RES_1_0]] : memref<64x32xf32>, vector<16x16xf32> -// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RES_0_1:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_1]], %[[TMP_0_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}], %[[RES_0_1]] : memref<64x32xf32>, vector<16x16xf32> -// CHECK-NEXT: %[[RES_1_1:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_1]], %[[TMP_1_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}], %[[RES_1_1]] : memref<64x32xf32>, vector<16x16xf32> -// CHECK-NEXT: %[[ACC_2_0:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[ACC_2_1:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[ACC_3_0:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[ACC_3_1:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}] : memref<64x32xf32> into vector<16x16xf32> -// CHECK-NEXT: %[[LHS_2_0:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[LHS_3_0:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[TMP_2_0:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_0]], %[[ACC_2_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[TMP_3_0:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_0]], %[[ACC_3_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[TMP_2_1:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_1]], %[[ACC_2_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[TMP_3_1:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_1]], %[[ACC_3_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: %[[LHS_2_1:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c32{{.*}}] : 
memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[LHS_3_1:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RES_2_0:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_0]], %[[TMP_2_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}], %[[RES_2_0]] : memref<64x32xf32>, vector<16x16xf32> -// CHECK-NEXT: %[[RES_3_0:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_0]], %[[TMP_3_0]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}], %[[RES_3_0]] : memref<64x32xf32>, vector<16x16xf32> -// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into vector<16x32xbf16> -// CHECK-NEXT: %[[RES_2_1:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_1]], %[[TMP_2_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}], %[[RES_2_1]] : memref<64x32xf32>, vector<16x16xf32> -// CHECK-NEXT: %[[RES_3_1:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_1]], %[[TMP_3_1]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> -// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}], %[[RES_3_1]] : memref<64x32xf32>, vector<16x16xf32> +// CHECK-NEXT: %[[ACC_0_0:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_0_1:.+]] = amx.tile_load %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_1_0:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_1_1:.+]] = amx.tile_load %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_0_0:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_1_0:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_0_0:.+]] = amx.tile_mulf %[[LHS_0_0]], %[[RHS_0_0]], %[[ACC_0_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_1_0:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_0]], %[[ACC_1_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_0_1:.+]] = amx.tile_mulf %[[LHS_0_0]], %[[RHS_0_1]], %[[ACC_0_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_1_1:.+]] = amx.tile_mulf %[[LHS_1_0]], %[[RHS_0_1]], %[[ACC_1_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_0_1:.+]] = amx.tile_load %[[LHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_1_1:.+]] = amx.tile_load %[[LHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : 
memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_0_0:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_0]], %[[TMP_0_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}], %[[RES_0_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_1_0:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_0]], %[[TMP_1_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c0{{.*}}], %[[RES_1_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_0_1:.+]] = amx.tile_mulf %[[LHS_0_1]], %[[RHS_1_1]], %[[TMP_0_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c0{{.*}}, %c16{{.*}}], %[[RES_0_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_1_1:.+]] = amx.tile_mulf %[[LHS_1_1]], %[[RHS_1_1]], %[[TMP_1_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c16{{.*}}, %c16{{.*}}], %[[RES_1_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_2_0:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_2_1:.+]] = amx.tile_load %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_3_0:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[ACC_3_1:.+]] = amx.tile_load %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}] : memref<64x32xf32> into !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_2_0:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_3_0:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c0{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_0_0:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_2_0:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_0]], %[[ACC_2_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_3_0:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_0]], %[[ACC_3_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_0_1:.+]] = amx.tile_load %[[RHS_BUF]][%c0{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[TMP_2_1:.+]] = amx.tile_mulf %[[LHS_2_0]], %[[RHS_0_1]], %[[ACC_2_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[TMP_3_1:.+]] = amx.tile_mulf %[[LHS_3_0]], %[[RHS_0_1]], %[[ACC_3_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[LHS_2_1:.+]] = amx.tile_load %[[LHS_BUF]][%c32{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[LHS_3_1:.+]] = amx.tile_load %[[LHS_BUF]][%c48{{.*}}, %c32{{.*}}] : memref<64x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RHS_1_0:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c0{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_2_0:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_0]], %[[TMP_2_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// 
CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c0{{.*}}], %[[RES_2_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_3_0:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_0]], %[[TMP_3_0]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c0{{.*}}], %[[RES_3_0]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RHS_1_1:.+]] = amx.tile_load %[[RHS_BUF]][%c16{{.*}}, %c32{{.*}}] : memref<32x64xbf16> into !amx.tile<16x32xbf16> +// CHECK-NEXT: %[[RES_2_1:.+]] = amx.tile_mulf %[[LHS_2_1]], %[[RHS_1_1]], %[[TMP_2_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c32{{.*}}, %c16{{.*}}], %[[RES_2_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> +// CHECK-NEXT: %[[RES_3_1:.+]] = amx.tile_mulf %[[LHS_3_1]], %[[RHS_1_1]], %[[TMP_3_1]] : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32> +// CHECK-NEXT: amx.tile_store %[[ACC_BUF]][%c48{{.*}}, %c16{{.*}}], %[[RES_3_1]] : memref<64x32xf32>, !amx.tile<16x16xf32> // CHECK: %[[RES:.+]] = vector.transfer_read %[[ACC_BUF]][%c0{{.*}}, %c0{{.*}}], %{{.*}} {in_bounds = [true, true]} : memref<64x32xf32>, vector<64x32xf32> #loc = loc(unknown) -#map = affine_map<(d0, d1, d2) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> module { tt.func public @test_loop_acc_two_blocks(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { %cst = arith.constant 0.000000e+00 : f8E5M2 loc(#loc) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp index 821ea6f954b2..f9a02592f5d5 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/TypeConverter.cpp @@ -1,5 +1,7 @@ #include "TypeConverter.h" +#include "mlir/Dialect/AMX/AMXDialect.h" + using namespace mlir; using namespace mlir::triton; @@ -13,6 +15,9 @@ TritonCPUToLLVMTypeConverter::TritonCPUToLLVMTypeConverter( addConversion([this](RankedTensorType type) -> std::optional { return convertTritonTensorType(type); }); + addConversion([&](amx::TileType type) { + return LLVM::LLVMX86AMXType::get(type.getContext()); + }); } Type TritonCPUToLLVMTypeConverter::convertTritonPointerType( diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index 73fcd627ea36..23f16944de41 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -379,7 +379,7 @@ Value getInitAccValue(Value val) { return forOp.getInitArgs()[initValIdx]; } -VectorType getSwizzledRhsTileType(VectorType origTileType) { +template T getSwizzledRhsTileType(T origTileType) { int64_t rowsPerGroup = 32 / origTileType.getElementTypeBitWidth(); SmallVector shape({origTileType.getDimSize(0) / rowsPerGroup, origTileType.getDimSize(1) * rowsPerGroup}); @@ -518,7 +518,7 @@ Value shiftIndex(Location loc, Value index, int64_t offs, } SmallVector shiftIndices(Location loc, ArrayRef indices, - VectorType tileTy, int64_t tilesInBlockM, + amx::TileType tileTy, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, int64_t tileM, int64_t tileN, 
PatternRewriter &rewriter) { @@ -530,7 +530,7 @@ SmallVector shiftIndices(Location loc, ArrayRef indices, shiftIndex(loc, indices[1], tileOffsN, rewriter)}; } -Value loadTile(Location loc, VectorType tileTy, const AmxBuffer &buf, +Value loadTile(Location loc, amx::TileType tileTy, const AmxBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, int64_t tileM, int64_t tileN, PatternRewriter &rewriter) { @@ -540,10 +540,10 @@ Value loadTile(Location loc, VectorType tileTy, const AmxBuffer &buf, return rewriter.create(loc, tileTy, buf.memRef, indices); } -void storeTile(Location loc, VectorType tileTy, Value val, const AmxBuffer &buf, - int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, - int64_t blockN, int64_t tileM, int64_t tileN, - PatternRewriter &rewriter) { +void storeTile(Location loc, amx::TileType tileTy, Value val, + const AmxBuffer &buf, int64_t tilesInBlockM, + int64_t tilesInBlockN, int64_t blockM, int64_t blockN, + int64_t tileM, int64_t tileN, PatternRewriter &rewriter) { auto indices = shiftIndices(loc, buf.indices, tileTy, tilesInBlockM, tilesInBlockN, blockM, blockN, tileM, tileN, rewriter); @@ -551,7 +551,7 @@ void storeTile(Location loc, VectorType tileTy, Value val, const AmxBuffer &buf, } SmallVector> -loadBlockTiles(Location loc, VectorType tileTy, const AmxBuffer &buf, +loadBlockTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, PatternRewriter &rewriter) { SmallVector> res(tilesInBlockM); @@ -570,7 +570,7 @@ loadBlockTiles(Location loc, VectorType tileTy, const AmxBuffer &buf, // Move acc to a tile for the whole loop. It might be loads from memory or // zero tiles. SmallVector> -moveLoopAccToTiles(Location loc, VectorType tileTy, const AmxBuffer &buf, +moveLoopAccToTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, PatternRewriter &rewriter) { LDBG("Loading accumulator to tiles before the loop."); @@ -588,8 +588,8 @@ moveLoopAccToTiles(Location loc, VectorType tileTy, const AmxBuffer &buf, // Multiply two blocks. LHS block is preloaded to tiles with the following // iteration over RHS. Accumulator values are updated in accTiles. // Optionally, results can also be stored to accBuf. -void multiplyBlocksPreloadLhs(Location loc, VectorType lhsTileTy, - VectorType rhsTileTy, VectorType accTileTy, +void multiplyBlocksPreloadLhs(Location loc, amx::TileType lhsTileTy, + amx::TileType rhsTileTy, amx::TileType accTileTy, const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, const AmxBuffer &accBuf, int64_t blockM, int64_t blockN, int64_t blockK, @@ -624,8 +624,8 @@ void multiplyBlocksPreloadLhs(Location loc, VectorType lhsTileTy, } // Similar to multiplyBlocksPreloadLhs but here RHS is preloaded to tiles. 
-void multiplyBlocksPreloadRhs(Location loc, VectorType lhsTileTy, - VectorType rhsTileTy, VectorType accTileTy, +void multiplyBlocksPreloadRhs(Location loc, amx::TileType lhsTileTy, + amx::TileType rhsTileTy, amx::TileType accTileTy, const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, const AmxBuffer &accBuf, int64_t blockM, int64_t blockN, int64_t blockK, @@ -667,15 +667,15 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, VectorType rhsTy = cast(op.getB().getType()); VectorType accTy = cast(op.getC().getType()); VectorType resTy = cast(op.getResult().getType()); - VectorType lhsTileTy = - lhsTy.cloneWith(SmallVector({candidate.tileM, candidate.tileK}), - candidate.lhsTileElemTy); - VectorType rhsTileTy = getSwizzledRhsTileType( - rhsTy.cloneWith(SmallVector({candidate.tileK, candidate.tileN}), - candidate.rhsTileElemTy)); - VectorType accTileTy = - accTy.cloneWith(SmallVector({candidate.tileM, candidate.tileN}), - candidate.accTileElemTy); + amx::TileType lhsTileTy = amx::TileType::get( + SmallVector({candidate.tileM, candidate.tileK}), + candidate.lhsTileElemTy); + amx::TileType rhsTileTy = getSwizzledRhsTileType(amx::TileType::get( + SmallVector({candidate.tileK, candidate.tileN}), + candidate.rhsTileElemTy)); + amx::TileType accTileTy = amx::TileType::get( + SmallVector({candidate.tileM, candidate.tileN}), + candidate.accTileElemTy); // If we don't work with a loop and want to directly store tiles into output // memory, then use the original store as insertion point to have its buffer From fcada66fdb874ce33f28386fd5b66c1974362f42 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 6 Dec 2024 20:16:19 +0000 Subject: [PATCH 147/165] Fix formatting --- bin/RegisterTritonDialects.h | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 85f17c611b3e..025e229962b5 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -81,15 +81,16 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::cpu::registerTritonOpScalarizeExternalModels(registry); // TODO: register Triton & TritonGPU passes - registry.insert(); + registry + .insert(); } From f2d3208e9488957087c367eeef1e0429674ef055 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 6 Dec 2024 20:16:37 +0000 Subject: [PATCH 148/165] Fix test_tl_range. 
Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_core.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e91e6d90dc82..e4e0b90819f1 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -6753,10 +6753,9 @@ def test_tl_range_num_stages(device): a = torch.randn((M, K), device=device, dtype=torch.float16) b = torch.randn((K, N), device=device, dtype=torch.float16) c = torch.empty((M, N), dtype=torch.float32, device=device) - pgm = matmul_kernel[ - 1, - ](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, BLOCK_N, - BLOCK_K, 0, num_stages=5) + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + pgm = matmul_kernel[grid](a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), + c.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, 0, num_stages=5) if is_cpu(): # torch.matmul not implemented for Half float (float16) cpu ref_out = torch.tensor(np.matmul(to_numpy(a), to_numpy(b)), dtype=torch.float32, device=device) From 4361d3463f666e830f4a9c34835c087135b661da Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 6 Dec 2024 20:18:10 +0000 Subject: [PATCH 149/165] Fix test_conversions. Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_conversions.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/test/unit/language/test_conversions.py b/python/test/unit/language/test_conversions.py index 14c46000c1ef..2cacd8c99991 100644 --- a/python/test/unit/language/test_conversions.py +++ b/python/test/unit/language/test_conversions.py @@ -7,12 +7,9 @@ import triton import triton.language as tl -from triton._internal_testing import is_cuda, is_hip, is_hip_mi300 +from triton._internal_testing import is_cuda, is_hip, is_hip_mi300, is_cpu -def is_cpu(): - return not is_interpreter() and triton.runtime.driver.active.get_current_target().backend == "cpu" - def matching_int(dtype): if dtype.primitive_bitwidth == 8: return torch.int8 From dccfd798c44cd6dd9b161826df10e743c1eb6c9d Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Fri, 6 Dec 2024 23:54:11 +0000 Subject: [PATCH 150/165] Disable test_block_copy with lower bound check. 
Signed-off-by: Ilya Enkovich --- python/test/unit/language/test_block_pointer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index fb2002101bb3..f9591b6505e2 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -37,6 +37,9 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, PADDING_OPTION: for boundary_check in (None, "lower", "upper") ]) def test_block_copy(dtypes_str, n, padding_option, boundary_check, device): + if is_cpu() and boundary_check == "lower": + pytest.skip("Lower boundary check is NYI for CPU") + src_dtype_str = dtypes_str[0] dst_dtype_str = dtypes_str[1] src_dtype = getattr(torch, src_dtype_str) From a509dd944c25dff6fdbc65a25112709bfce0afc7 Mon Sep 17 00:00:00 2001 From: Minjang Kim Date: Mon, 9 Dec 2024 15:11:48 -0800 Subject: [PATCH 151/165] Fix isSigned and add float16 in PrintOp (#191) * Fix isSigned in PrintOp * Add float16 support for print * Support float16 printing for old compilers --- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 6 +- .../lib/TritonToTritonCPU/ConvertDebugOps.cpp | 8 ++- third_party/cpu/runtime/cpu_runtime.cpp | 62 +++++++++++++++++-- 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index a6d1487ebcc7..21b4756b506e 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -190,7 +190,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, void createRuntimePrintCall(ConversionPatternRewriter &rewriter, std::array pid, StringRef prefix, - Value ptr, Type dtype, bool hex) { + Value ptr, Type dtype, bool isSigned, bool hex) { assert(!prefix.empty()); auto loc = UnknownLoc::get(rewriter.getContext()); Value prefixValue = LLVM::addStringToModule( @@ -205,7 +205,7 @@ void createRuntimePrintCall(ConversionPatternRewriter &rewriter, allArgs.push_back(i32_val(dtype.getIntOrFloatBitWidth())); allArgs.push_back(i32_val(dtype.isInteger())); - allArgs.push_back(i32_val(dtype.isSignedInteger())); + allArgs.push_back(i32_val(isSigned)); allArgs.push_back(i32_val(hex)); call(getOrAddPrintMemrefFuncDecl(rewriter), allArgs); @@ -254,7 +254,7 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { createRuntimePrintCall( rewriter, pid, op.getPrefix(), adaptor.getOperands()[0], cast(op.getVal()[0].getType()).getElementType(), - op.getHex()); + op.getIsSigned()[0], op.getHex()); rewriter.eraseOp(op); return success(); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp index 83fe858fb139..80edcf69f239 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertDebugOps.cpp @@ -61,11 +61,13 @@ struct PrintOpConversion : public OpConversionPattern { return success(); } - for (auto operand : op.getOperands()) { + for (size_t i = 0; i < op.getNumOperands(); i++) { + Value operand = op.getOperands()[i]; + auto isSigned = {op.getIsSigned()[i]}; if (!isa(operand.getType())) { rewriter.create( loc, op.getPrefix(), op.getHex(), - rewriter.getRemappedValue(operand), false); + rewriter.getRemappedValue(operand), isSigned); continue; } @@ -92,7 +94,7 @@ struct PrintOpConversion : public OpConversionPattern { allocVal); 
rewriter.create(loc, op.getPrefix(), op.getHex(), - allocUnrankedVal, false); + allocUnrankedVal, isSigned); rewriter.create(loc, allocVal); } diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 537441903212..68b7efa78f01 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -9,6 +9,9 @@ #include #include +#define __STDC_WANT_IEC_60559_TYPES_EXT__ +#include + #if defined(_MSC_VER) #define EXPORT __declspec(dllexport) #elif defined(__GNUC__) @@ -24,6 +27,42 @@ const int MAX_FLOAT_WIDTH = 8; const int FLOAT_PREC = 4; const int ELEMS_PER_LINE = 8; +using FLOAT16 = struct _FLOAT16 { +#ifdef FLT16_MAX + _Float16 x; +#else + uint16_t x; +#endif + + float toFloat32() const { +#ifdef FLT16_MAX + return static_cast(x); +#else + // Based on https://gist.github.com/zhuker/b4bd1fb306c7b04975b712c37c4c4075 + uint32_t t1; + uint32_t t2; + uint32_t t3; + + t1 = x & 0x7fffu; // Non-sign bits + t2 = x & 0x8000u; // Sign bit + t3 = x & 0x7c00u; // Exponent + + t1 <<= 13u; // Align mantissa on MSB + t2 <<= 16u; // Shift sign bit into position + + t1 += 0x38000000; // Adjust bias + + t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero + + t1 |= t2; // Re-insert sign bit + + float out; + *((uint32_t *)&out) = t1; + return out; +#endif + } +}; + struct FormatInfo { bool isInt; bool isSigned; @@ -91,6 +130,12 @@ std::pair computeDigitInfo(T val) { return {digits, val < 0}; } +template <> +std::pair +computeDigitInfo(FLOAT16 val) { + return computeDigitInfo(val.toFloat32()); +} + template std::tuple computeDigitStats(const MemRefDescriptor &desc) { int maxIntDigits = 0; @@ -177,6 +222,12 @@ void printFormattedElement(std::stringstream &ss, uint8_t val, printFormattedElement(ss, val, formatInfo); } +template <> +void printFormattedElement(std::stringstream &ss, FLOAT16 val, + const FormatInfo &formatInfo) { + printFormattedElement(ss, val.toFloat32(), formatInfo); +} + template void printToStreamRecursive(const MemRefDescriptor &desc, std::stringstream &ss, const FormatInfo &formatInfo, @@ -247,6 +298,10 @@ void printMemRef(std::stringstream &ss, int32_t rank, void *descriptor, printToStream(MemRefDescriptor(rank, descriptor), ss, partialFormat, linePrefix); return; + case 16: + printToStream(MemRefDescriptor(rank, descriptor), ss, + partialFormat, linePrefix); + return; default: llvm_unreachable("Unsupported bitWidth"); } @@ -325,14 +380,13 @@ EXPORT void triton_assert(int32_t pid0, int32_t pid1, int32_t pid2, bool cond, EXPORT void triton_print_unranked_memref(int32_t pid0, int32_t pid1, int32_t pid2, const char *prefix, UnrankedMemRefType memref, int32_t btw, - bool isInteger, bool isSignedInteger, + bool isInteger, bool isSigned, bool asHex) { std::stringstream ss; ss << "(" << pid0 << ", " << pid1 << ", " << pid2 << ")" << prefix; std::string linePrefix(ss.str().size(), ' '); - - printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger, - isSignedInteger, asHex, linePrefix); + printMemRef(ss, memref.rank, memref.descriptor, btw, isInteger, isSigned, + asHex, linePrefix); ss << "\n"; std::cout << ss.str() << std::flush; } From df3843047068ac21d4d155272ca226dc66cbb604 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 10 Dec 2024 14:03:43 -0600 Subject: [PATCH 152/165] Add TritonCPU canonicalizer. 
(#192) Signed-off-by: Ilya Enkovich --- test/TritonCPU/canonicalize.mlir | 30 +++++ third_party/cpu/backend/compiler.py | 1 + .../cpu/include/TritonCPUTransforms/Passes.h | 1 + .../cpu/include/TritonCPUTransforms/Passes.td | 13 +++ .../lib/TritonCPUTransforms/CMakeLists.txt | 1 + .../lib/TritonCPUTransforms/Canonicalize.cpp | 110 ++++++++++++++++++ third_party/cpu/triton_cpu.cc | 3 + 7 files changed, 159 insertions(+) create mode 100644 test/TritonCPU/canonicalize.mlir create mode 100644 third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp diff --git a/test/TritonCPU/canonicalize.mlir b/test/TritonCPU/canonicalize.mlir new file mode 100644 index 000000000000..9e14645861bb --- /dev/null +++ b/test/TritonCPU/canonicalize.mlir @@ -0,0 +1,30 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-canonicalize | FileCheck %s + +// Fold transfer read and shape cast. + +// CHECK-LABEL: @fold_transfer_read_shape_cast +// CHECK: %[[VAL:.+]] = vector.transfer_read +// CHECK: vector.transfer_write %[[VAL]] + +module { + tt.func public @fold_transfer_read_shape_cast(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + %cst = arith.constant 0.000000e+00 : bf16 + %c0 = arith.constant 0 : index + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %c2_i64 = arith.constant 2 : i64 + %c16_i64 = arith.constant 16 : i64 + %c256_i64 = arith.constant 256 : i64 + %c512_i64 = arith.constant 512 : i64 + %in_p = tt.make_tensor_ptr %arg0, [%c2_i64, %c2_i64, %c16_i64, %c16_i64], [%c512_i64, %c256_i64, %c16_i64, %c1_i64], [%c0_i32, %c0_i32, %c0_i32, %c0_i32] {order = array} : > + %out_p = tt.make_tensor_ptr %arg1, [%c16_i64, %c16_i64], [%c16_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %memref1 = triton_cpu.extract_memref %in_p : > -> memref<2x2x16x16xbf16, strided<[512, 256, 16, 1]>> + %indices1:4 = triton_cpu.extract_indices %in_p : > -> index, index, index, index + %val1 = vector.transfer_read %memref1[%indices1#0, %indices1#1, %indices1#2, %indices1#3], %cst {in_bounds = [true, true, true, true]} : memref<2x2x16x16xbf16, strided<[512, 256, 16, 1]>>, vector<1x1x16x16xbf16> + %val2 = vector.shape_cast %val1 : vector<1x1x16x16xbf16> to vector<16x16xbf16> + %memref2 = triton_cpu.extract_memref %out_p : > -> memref<16x16xbf16, strided<[16, 1]>> + %indices2:2 = triton_cpu.extract_indices %out_p : > -> index, index + vector.transfer_write %val2, %memref2[%indices2#0, %indices2#1] {in_bounds = [true, true]} : vector<16x16xbf16>, memref<16x16xbf16, strided<[16, 1]>> + tt.return + } +} diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 5395afe282f8..3670ee298e15 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -160,6 +160,7 @@ def make_tttcir(self, mod, metadata, opt): # TTCIR -> Target TTCIR pm = ir.pass_manager(mod.context) pm.enable_debug() + cpu.passes.ttcpuir.add_triton_cpu_canonicalizer(pm) cpu.passes.ttcpuir.add_optimize_masks(pm) passes.common.add_canonicalizer(pm) convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8") diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index 9b7402d7f0f8..7d3ecaf5c515 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -37,6 +37,7 @@ std::unique_ptr> createConvertDotToAMX(); std::unique_ptr> createConvertDotToAMX(bool convertInt8, bool convertFp16, 
bool convertBf16); std::unique_ptr> createConvertDotGeneric(); +std::unique_ptr> createCanonicalize(); #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index a38673595538..42df78e0bb12 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -137,4 +137,17 @@ def ConvertDotGeneric : Pass<"triton-cpu-convert-dot-generic", "mlir::ModuleOp"> "mlir::triton::cpu::TritonCPUDialect"]; } +def Canonicalize : Pass<"triton-cpu-canonicalize", "mlir::ModuleOp"> { + let summary = "Canonicalization pass."; + let description = [{ + This pass applies various foldings to simplify analysis and transformations + in optimization passes. + }]; + + let constructor = "mlir::triton::cpu::createCanonicalize()"; + + let dependentDialects = ["mlir::vector::VectorDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt index c421e35f8797..17d2560e6187 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonCPUTransforms ConvertDotOp/ConvertDotGeneric.cpp ConvertDotOp/ConvertDotToAMX.cpp + Canonicalize.cpp ConvertDotProduct.cpp ConvertUnsupportedOps.cpp DecomposeFpConversions.cpp diff --git a/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp b/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp new file mode 100644 index 000000000000..65fed92d2b50 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp @@ -0,0 +1,110 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CANONICALIZE +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// Fold transfer read and the following shape cast that removes heading +// dimensions with size 1. +struct FoldReadShapeCast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferReadOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + if (!op->hasOneUse()) + return failure(); + + auto permMap = op.getPermutationMap(); + if (!permMap.isMinorIdentity()) + return failure(); + + auto reshape = dyn_cast(*op->user_begin()); + if (!reshape) + return failure(); + + VectorType ty = cast(op.getType()); + VectorType dstTy = cast(reshape.getType()); + if (ty.getRank() <= dstTy.getRank()) + return failure(); + + // Check all removed dimensions have size 1. + if (!all_of(drop_end(ty.getShape(), dstTy.getRank()), + [](int64_t val) { return val == 1; })) + return failure(); + + // Check shape prefix matches the resulting type. 
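+    // (Illustration, matching the canonicalize.mlir test above: a transfer_read
+    // producing vector<1x1x16x16xbf16> followed by a shape_cast to
+    // vector<16x16xbf16> becomes a single transfer_read of vector<16x16xbf16>,
+    // so the trailing dims of the source type must equal the result shape.)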
+ if (!equal(drop_begin(ty.getShape(), ty.getRank() - dstTy.getRank()), + dstTy.getShape())) + return failure(); + + auto inBounds = op.getInBounds(); + if (std::any_of(inBounds.begin(), inBounds.end() - dstTy.getRank(), + [](Attribute attr) { + return !cast(attr).getValue(); + })) + return failure(); + + // Fold read and shape cast into a single read. + auto newPermMap = permMap.getMinorIdentityMap( + permMap.getNumDims(), dstTy.getRank(), getContext()); + auto newInBounds = rewriter.getArrayAttr(SmallVector(drop_begin( + op.getInBounds().getValue(), ty.getRank() - dstTy.getRank()))); + auto newRead = rewriter.create( + loc, dstTy, op.getSource(), op.getIndices(), newPermMap, + op.getPadding(), op.getMask(), newInBounds); + rewriter.replaceOp(reshape, newRead); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct Canonicalize : public triton::cpu::impl::CanonicalizeBase { + Canonicalize() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + RewritePatternSet patterns(context); + patterns.add(context); + + if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createCanonicalize() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 40f56204c427..8206b021f357 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -76,6 +76,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_convert_debug_ops", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertDebugOps()); }); + m.def("add_triton_cpu_canonicalizer", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createCanonicalize()); + }); m.def("add_optimize_masks", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createOptimizeMasks()); }); From 90908d1402decc8a645db94bf7ca7e8aaabcbd30 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 11 Dec 2024 18:46:15 -0600 Subject: [PATCH 153/165] Introduce FMA lowering for DotOp. (#193) * Add pass to decompose matmul to FMA operations. Signed-off-by: Ilya Enkovich * Use block pointers and padding in 03-matrix-multiplication-cpu.py. * Fix review comments. 
Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- .../tutorials/03-matrix-multiplication-cpu.py | 152 ++++-- python/tutorials/cpu-blocked-matmul-fp32.py | 373 ++++++++++++++ third_party/cpu/backend/compiler.py | 2 + .../cpu/include/TritonCPUTransforms/Passes.h | 1 + .../cpu/include/TritonCPUTransforms/Passes.td | 12 + .../lib/TritonCPUTransforms/CMakeLists.txt | 2 + .../ConvertDotOp/ConvertDotCommon.cpp | 190 +++++++ .../ConvertDotOp/ConvertDotCommon.h | 72 +++ .../ConvertDotOp/ConvertDotToFMA.cpp | 462 ++++++++++++++++++ third_party/cpu/triton_cpu.cc | 3 + 10 files changed, 1228 insertions(+), 41 deletions(-) create mode 100644 python/tutorials/cpu-blocked-matmul-fp32.py create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index c61a8098eac4..f2ee03dfadc2 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -153,14 +153,36 @@ import triton import triton.language as tl +import os -BLOCK_SIZE_M = 32 +DTYPE = getattr(torch, (os.getenv("DTYPE", "float32"))) +# Chosse block size depending on dtype. We have more register +# capacity for bfloat16/float16 compared to float32. +BLOCK_SIZE_M = 8 if DTYPE == torch.float32 else 32 BLOCK_SIZE_N = 32 -BLOCK_SIZE_K = 32 +BLOCK_SIZE_K = 8 if DTYPE == torch.float32 else 32 +CACHE_PADDING = os.getenv("CACHE_PADDING", "0") != "0" +PREPACKED = os.getenv("PREPACKED", "0") != "0" +PAD_B_ONLY = True +USE_BLOCK_POINTERS = os.getenv("USE_BLOCK_POINTERS", "1") != "0" GROUP_SIZE_M = 8 USE_GPU = False +@triton.jit +def pad_kernel(in_ptr, out_ptr, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, PADDING: tl.constexpr): + in_offset = tl.program_id(axis=0) * N * BLOCK_SIZE_M + out_offset = tl.program_id(axis=0) * (N + PADDING) * BLOCK_SIZE_M + for row in tl.range(0, BLOCK_SIZE_M): + for block in tl.range(0, N // BLOCK_SIZE_N): + val = tl.load(in_ptr + in_offset + block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + tl.store(out_ptr + out_offset + block * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N), val) + zero = tl.full((PADDING, ), 0, dtype=in_ptr.type.element_ty) + tl.store(out_ptr + out_offset + N + tl.arange(0, PADDING), zero) + in_offset += N + out_offset += N + PADDING + + @triton.jit def matmul_kernel( # Pointers to matrices @@ -176,6 +198,7 @@ def matmul_kernel( # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # GROUP_SIZE_M: tl.constexpr, # + USE_BLOCK_POINTERS: tl.constexpr, # ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -198,14 +221,21 @@ def matmul_kernel( # Create pointers for the first blocks of A and B. 
# We will advance this pointer as we move in the K direction # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - # See above `Pointer Arithmetic` section for details - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + if USE_BLOCK_POINTERS: + block_offset_m = pid_m * BLOCK_SIZE_M + block_offset_n = pid_n * BLOCK_SIZE_N + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), + order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(1, 0)) + else: + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_tile_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_tile_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -217,43 +247,60 @@ def matmul_kernel( # Load the next block of A and B, generate a mask by checking the K dimension. # If it is out of bounds, set it to 0. - # TODO: Currently masked load is not supported yet. - # a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) + a = tl.load(a_tile_ptr) + b = tl.load(b_tile_ptr) # We accumulate along the K dimension. accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk + if USE_BLOCK_POINTERS: + a_tile_ptr = tl.advance(a_tile_ptr, [0, BLOCK_SIZE_K]) + b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0]) + else: + a_tile_ptr += BLOCK_SIZE_K * stride_ak + b_tile_ptr += BLOCK_SIZE_K * stride_bk # Convert the accumulator to the output matrix C's type if needed. c = accumulator # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - - # TODO: Currently masked load is not supported yet. - # c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - # tl.store(c_ptrs, c, mask=c_mask) - tl.store(c_ptrs, c) + # Write back the block of the output matrix C. 
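+    # (Masking is still unsupported here, so the host-side wrapper below asserts
+    # that M, N and K are multiples of the block sizes.)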
+ if USE_BLOCK_POINTERS: + c_tile_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), + offsets=(block_offset_m, block_offset_n), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) + tl.store(c_tile_ptr, c) + else: + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_tile_ptr = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_tile_ptr, c) # %% # We can now create a convenience wrapper function that only takes two input tensors, # and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. +a_scratch = torch.empty((), dtype=DTYPE) +b_scratch = torch.empty((), dtype=DTYPE) -def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): + +def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): # Check constraints. assert a.shape[1] == b.shape[0], "Incompatible dimensions" assert a.is_contiguous(), "Matrix A must be contiguous" M, K = a.shape K, N = b.shape + + # TODO: Check if padding is needed at all. + if CACHE_PADDING: + a_scratch.resize_(M, K + 32) + b_scratch.resize_(K, N + 32) + if not PAD_B_ONLY: + pad_kernel[(M // BLOCK_SIZE_M, )](a, a_scratch, K, BLOCK_SIZE_M, BLOCK_SIZE_K, 32, num_threads=num_threads) + a = a_scratch + pad_kernel[(K // BLOCK_SIZE_K, )](b, b_scratch, N, BLOCK_SIZE_K, BLOCK_SIZE_N, 32, num_threads=num_threads) + b = b_scratch + #TODO: Currently masked load is not supported yet. assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" @@ -262,6 +309,14 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): c = torch.empty((M, N), device=a.device, dtype=a.dtype) else: assert c.shape == (M, N), "Incompatible dimensions" + + return a, b, c + + +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, M: int, N: int, K: int, num_threads=0): + if not PREPACKED: + a, b, c = matmul_preprocess_input(a, b, c, num_threads=num_threads) + # 1D launch kernel where each block gets its own program. 
grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), ) matmul_kernel[grid]( @@ -272,6 +327,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): c.stride(0), c.stride(1), # BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # + USE_BLOCK_POINTERS=USE_BLOCK_POINTERS, # num_threads=num_threads, # ) return c @@ -287,10 +343,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): triton.runtime.driver.set_active_to_cpu() -a = torch.randn((512, 512), device='cpu', dtype=torch.float32) -b = torch.randn((512, 512), device='cpu', dtype=torch.float32) -triton_output = matmul(a, b, None) -torch_output = torch.matmul(a, b) +a = torch.randn((512, 512), device='cpu', dtype=DTYPE) +b = torch.randn((512, 512), device='cpu', dtype=DTYPE) +c = None +torch_output = torch.matmul(a.to(torch.float32), b.to(torch.float32)) +if PREPACKED: + a, b, c = matmul_preprocess_input(a, b, c) +triton_output = matmul(a, b, c, 512, 512, 512) print(f"triton_cpu_output_with_{a.dtype}_inputs={triton_output}") print(f"torch_cpu_output_with_{a.dtype}_inputs={torch_output}") rtol = 0 @@ -310,9 +369,9 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): # We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices, # but feel free to arrange this script as you wish to benchmark any other matrix shape. -LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native', 'torch-cpu-compile'] -LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)', 'TorchCPU (compile)'] -LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--'), ('green', '-')] +LINE_VALS = ['triton-cpu-single', 'triton-cpu', 'torch-cpu-native'] +LINE_NAMES = ['TritonCPU 1', 'TritonCPU', 'TorchCPU (native)'] +LINE_STYLES = [('blue', '--'), ('blue', '-'), ('green', '--')] if USE_GPU and triton.runtime.driver.get_active_gpus(): triton.runtime.driver.set_active_to_gpu() @@ -356,36 +415,47 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, num_threads=0): ylabel='GFLOPS', # Label name for the y-axis. plot_name= # Name for the plot. Used also as a file name for saving the plot. - f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', + f'matmul-performance-{DTYPE} (USE_BLOCK_POINTERS={USE_BLOCK_POINTERS} CACHE_PADDING={CACHE_PADDING} PREPACKED={PREPACKED} PAD_B_ONLY={PAD_B_ONLY} GROUP_SIZE_M={GROUP_SIZE_M})', args={}, # Values for function arguments not in `x_names` and `y_name`. 
)) def benchmark(M, N, K, provider): device = 'cpu' if 'cpu' in provider else 'cuda' - a = torch.randn((M, K), device=device, dtype=torch.float32) - b = torch.randn((K, N), device=device, dtype=torch.float32) + a = torch.randn((M, K), device=device, dtype=DTYPE) + b = torch.randn((K, N), device=device, dtype=DTYPE) if device == 'cpu': - c = torch.empty((M, N), device=a.device, dtype=a.dtype) + if 'triton-cpu' in provider: + c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + else: + c = torch.zeros((M, N), device=a.device, dtype=a.dtype) triton.runtime.driver.set_active_to_cpu() else: c = None triton.runtime.driver.set_active_to_gpu() + if PREPACKED: + triton_a, triton_b, triton_c = matmul_preprocess_input(a, b, c) + else: + triton_a, triton_b, triton_c = a, b, c + quantiles = [0.5, 0.2, 0.8] if provider == 'torch-gpu': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles) elif provider == 'triton-gpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, None), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(triton_a, triton_b, None), quantiles=quantiles) elif provider == 'torch-cpu-native': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles) elif provider == 'torch-cpu-compile': compiled = torch.compile(torch.matmul) ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles) elif provider == 'triton-cpu-single': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c, num_threads=1), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul(triton_a, triton_b, triton_c, M, N, K, num_threads=1), quantiles=quantiles, + measure_time_with_hooks=True) elif provider == 'triton-cpu': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, c), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(triton_a, triton_b, triton_c, M, N, K), + quantiles=quantiles, measure_time_with_hooks=True) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/cpu-blocked-matmul-fp32.py b/python/tutorials/cpu-blocked-matmul-fp32.py new file mode 100644 index 000000000000..0df8f41b9b2b --- /dev/null +++ b/python/tutorials/cpu-blocked-matmul-fp32.py @@ -0,0 +1,373 @@ +""" +Matrix Multiplication +===================== +In this tutorial, matmul on CPU with different input layouts is tested. + +This tutorial is optimized for AMX-enabled CPUs. + +""" + +# %% +# Kernels +# ------- + +import torch + +import triton +import triton.language as tl + +BLOCK_SIZE_M = 8 +BLOCK_SIZE_N = 32 +BLOCK_SIZE_K = 8 +GROUP_SIZE_M = 8 + + +# This kernel is used for blocked encoding of input tensors for matmul. +# +# Blocked encoding is used to transform 2D tensor [M, N] into 4D tensor +# [M / BLOCK_SIZE_M, N / BLOCK_SIZE_N, BLOCK_SIZE_M, BLOCK_SIZE_N]. +# This makes following access to blocks in matmul more efficient because +# each block is placed into a contiguous memory fragment and is likely +# to fit a single memory page. +# +# If TRANSPOSED_B is set to True then head dimensions of the RHS +# tensor are transposed. It provides contiguos placement for a column +# of blocks. +# +# If TRANSPOSED_BLOCK_A is set to True then tail dimensions of the LHS +# tensor are transposed. 
Transposed LHS block better matches FMA lowering +# used by Triton CPU backend which processes RHS block row-by-row and LHS +# block column-by-column. +@triton.jit +def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, + TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, + TRANSPOSED_B: tl.constexpr): + tl.static_assert(M % BLOCK_SIZE_M == 0) + tl.static_assert(N % BLOCK_SIZE_N == 0) + tl.static_assert(BLOCKED_A or not TRANSPOSED_BLOCK_A) + tl.static_assert(BLOCKED_B or not TRANSPOSED_B) + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + in_block_m = first_pid_m + (pid % group_size_m) + in_block_n = (pid % num_pid_in_group) // group_size_m + + if BLOCKED_A: + a_out_block_m = in_block_m + A_OUT_BLOCK_SIZE_M: tl.constexpr = BLOCK_SIZE_K if TRANSPOSED_BLOCK_A else BLOCK_SIZE_M + A_OUT_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_M if TRANSPOSED_BLOCK_A else BLOCK_SIZE_K + A_OUT_BLOCKS_M: tl.constexpr = M // BLOCK_SIZE_M + A_OUT_BLOCKS_K: tl.constexpr = K // BLOCK_SIZE_K + A_OUT_STRIDE_M: tl.constexpr = A_OUT_BLOCK_SIZE_K + A_OUT_STRIDE_BLOCK_M: tl.constexpr = BLOCK_SIZE_M * K + A_OUT_STRIDE_BLOCK_K: tl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_K + for in_block_k in tl.range(in_block_n, A_OUT_BLOCKS_K, N // BLOCK_SIZE_N): + a_out_block_k = in_block_k + a_in_ptr = tl.make_block_ptr(base=in_a, shape=(M, K), strides=(K, 1), + offsets=(in_block_m * BLOCK_SIZE_M, in_block_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0)) + a_out_ptr = tl.make_block_ptr( + base=out_a, shape=(A_OUT_BLOCKS_M, A_OUT_BLOCKS_K, A_OUT_BLOCK_SIZE_M, A_OUT_BLOCK_SIZE_K), + strides=(A_OUT_STRIDE_BLOCK_M, A_OUT_STRIDE_BLOCK_K, A_OUT_STRIDE_M, 1), + offsets=(a_out_block_m, a_out_block_k, 0, 0), + block_shape=(1, 1, A_OUT_BLOCK_SIZE_M, A_OUT_BLOCK_SIZE_K), order=(3, 2, 1, 0)) + val = tl.load(a_in_ptr) + if TRANSPOSED_BLOCK_A: + val = val.T + val = tl.reshape(val, (1, 1, A_OUT_BLOCK_SIZE_M, A_OUT_BLOCK_SIZE_K)) + tl.store(a_out_ptr, val) + + if BLOCKED_B: + B_OUT_BLOCKS_K: tl.constexpr = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K + B_OUT_BLOCKS_N: tl.constexpr = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N + B_OUT_STRIDE_K: tl.constexpr = BLOCK_SIZE_N + B_OUT_STRIDE_BLOCK_K: tl.constexpr = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N) + B_OUT_STRIDE_BLOCK_N: tl.constexpr = BLOCK_SIZE_K * BLOCK_SIZE_N + for in_block_k in tl.range(in_block_m, K // BLOCK_SIZE_K, M // BLOCK_SIZE_M): + b_out_block_k = in_block_n if TRANSPOSED_B else in_block_k + b_out_block_n = in_block_k if TRANSPOSED_B else in_block_n + b_in_ptr = tl.make_block_ptr(base=in_b, shape=(K, N), strides=(N, 1), + offsets=(in_block_k * BLOCK_SIZE_K, in_block_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) + b_out_ptr = tl.make_block_ptr(base=out_b, + shape=(B_OUT_BLOCKS_K, B_OUT_BLOCKS_N, BLOCK_SIZE_K, BLOCK_SIZE_N), + strides=(B_OUT_STRIDE_BLOCK_K, B_OUT_STRIDE_BLOCK_N, B_OUT_STRIDE_K, 1), + offsets=(b_out_block_k, b_out_block_n, 0, 0), + block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), order=(3, 2, 1, 0)) + val = tl.load(b_in_ptr) + val = 
tl.reshape(val, (1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N)) + tl.store(b_out_ptr, val) + + +# Matmul kernel that computes a single output block [BLOCK_SIZE_M, BLOCK_SIZE_N]. LHS can be in the +# rowmajor, blocked, or blocked transposed encoding. RHS can be in rowmajor, blocked, or transposed +# blocked encoding. +# +# To cover all input layouts, we use 4D block pointers that address a single input block +# [1, 1, BLOCK_SIZE_M, BLOCK_SIZE_N], we choose strides for these block pointers +# appropriately to keep navigation bentween blocks similar for all input encodings. +# +# E.g. for rowmajor LHS we use BLOCK_SIZE_K stride to move to the next block over K axis, but +# for blocked encoding we use BLOCK_SIZE_M * BLOCK_SIZE_K stride. In both cases we then can +# advance using the same (0, 1, 0, 0) offset in the loop. +# +# Reshape is used to remove the heading (1, 1) dimensions, but CPU backend folds it with the load +# operation and it doesn't prevent direct vector loads from the input memory. +@triton.jit +def matmul_kernel_fma(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + # number of blocks in a group + GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, + BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): + # TRANSPOSED_BLOCK_A means that each block in A is transposed. + # It is allowed only for blocked input. + assert (BLOCKED_A or not TRANSPOSED_BLOCK_A) + # TRANSPOSED_B means that blocks of B are reordered but blocks + # itself are not transpoed. It is allowed only for blocked input. + assert (BLOCKED_B or not TRANSPOSED_B) + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + block_m = first_pid_m + (pid % group_size_m) + block_n = (pid % num_pid_in_group) // group_size_m + + A_BLOCK_SIZE_M: tl.constexpr = BLOCK_SIZE_K if TRANSPOSED_BLOCK_A else BLOCK_SIZE_M + A_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_M if TRANSPOSED_BLOCK_A else BLOCK_SIZE_K + A_BLOCKS_M = M // BLOCK_SIZE_M + A_BLOCKS_K = K // BLOCK_SIZE_K + a_stride_k = 1 + a_stride_m = A_BLOCK_SIZE_K if BLOCKED_A else K + a_stride_block_k = A_BLOCK_SIZE_M * A_BLOCK_SIZE_K if BLOCKED_A else A_BLOCK_SIZE_K + a_stride_block_m = BLOCK_SIZE_M * K + + b_stride_n = 1 + b_stride_k = BLOCK_SIZE_N if BLOCKED_B else N + if TRANSPOSED_B: + b_stride_block_n = BLOCK_SIZE_N * K + b_stride_block_k = BLOCK_SIZE_K * BLOCK_SIZE_N + else: + b_stride_block_n = BLOCK_SIZE_K * BLOCK_SIZE_N if BLOCKED_B else BLOCK_SIZE_N + b_stride_block_k = BLOCK_SIZE_K * N + + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(A_BLOCKS_M, A_BLOCKS_K, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), + strides=(a_stride_block_m, a_stride_block_k, a_stride_m, a_stride_k), + offsets=(block_m, 0, 0, 0), block_shape=(1, 1, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), + order=(3, 2, 1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, + shape=(K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_K, BLOCK_SIZE_N), + strides=(b_stride_block_k, b_stride_block_n, b_stride_k, b_stride_n), + offsets=(0, block_n, 0, 0), block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(3, 2, 1, 0)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(N, 1), + offsets=(block_m * BLOCK_SIZE_M, block_n * BLOCK_SIZE_N), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) + + 
c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_block_ptr).reshape((A_BLOCK_SIZE_M, A_BLOCK_SIZE_K)) + b = tl.load(b_block_ptr).reshape((BLOCK_SIZE_K, BLOCK_SIZE_N)) + + if TRANSPOSED_BLOCK_A: + a = a.T + + c += tl.dot(a, b, out_dtype=tl.float32) + + a_block_ptr = tl.advance(a_block_ptr, (0, 1, 0, 0)) + b_block_ptr = tl.advance(b_block_ptr, (1, 0, 0, 0)) + + tl.store(c_block_ptr, c) + + +def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, + PREPACKED, BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0): + #TODO: Currently masked load is not supported yet. + assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( + K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" + # 1D launch kernel where each block gets its own program. + grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), ) + if (BLOCKED_A or BLOCKED_B) and not PREPACKED: + block_transpose_combined_kernel[grid]( + a, ab, b, bb, # + M, N, K, # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B) + if BLOCKED_A: + a = ab + if BLOCKED_B: + b = bb + matmul_kernel_fma[grid]( + a, b, c, # + M, N, K, # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, num_threads=num_threads) + return c + + +# %% +# Unit Test +# --------- +# +# We can test our custom matrix multiplication operation against a native torch implementation. +torch.manual_seed(0) + +triton.runtime.driver.set_active_to_cpu() + +a = torch.randn((512, 512), device='cpu', dtype=torch.float32) +b = torch.randn((512, 512), device='cpu', dtype=torch.float32) +c = torch.empty((512, 512), device='cpu', dtype=torch.float32) +torch_output = torch.matmul(a, b) +rtol = 0 +a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=torch.float32) +b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=torch.float32) +triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonCPU and TorchCPU match") +else: + print("❌ TritonCPU and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + assert False +triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ TritonCPU pre-packed and TorchCPU match") +else: + print("❌ TritonCPU pre-packed and TorchCPU differ, the maximum difference is " + f'{torch.max(torch.abs(triton_output - torch_output))}') + assert False + +# %% +# Benchmark +# --------- +# +# Square Matrix Performance +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We can now compare the performance of our kernel against that of Pytorch. Here we focus on square matrices, +# but feel free to arrange this script as you wish to benchmark any other matrix shape. 
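+
+# The provider suffixes used below ('-ba', '-ta', '-bb', '-tb') select the blocked and
+# transposed layouts produced by block_transpose_combined_kernel. As a rough PyTorch-style
+# sketch (illustration only, assuming shapes divisible by the block sizes), the packed
+# tensors correspond to:
+#   blocked B:          b.reshape(K // BLOCK_SIZE_K, BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_N).permute(0, 2, 1, 3)
+#   transposed B:       the same with the two block dimensions swapped, i.e. permute(2, 0, 1, 3) of the reshape
+#   transposed-block A: a.reshape(M // BLOCK_SIZE_M, BLOCK_SIZE_M, K // BLOCK_SIZE_K, BLOCK_SIZE_K).permute(0, 2, 3, 1)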
+ + +def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype): + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' + return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-{dtype}" + + +def encode_torch_provider(single_thread, dtype): + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' + return f"torch-cpu-native{'-st' if single_thread else ''}-{dtype}" + + +def decode_provider(provider): + if '-bfloat16' in provider: + dtype = torch.bfloat16 + if '-float16' in provider: + dtype = torch.float16 + elif '-float32' in provider: + dtype = torch.float32 + if 'triton-cpu' in provider: + backend = 'triton-cpu' + elif 'torch-cpu-native' in provider: + backend = 'torch-cpu-native' + elif 'torch-cpu-compile' in provider: + backend = 'torch-cpu-compile' + return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-prepack' in provider, '-st' in provider, dtype + + +BLOCK_TRANSPOSE_A_OPTS = [(False, False)] +BLOCK_TRANSPOSE_B_OPTS = [(True, True), (False, False)] +PREPACK_OPTS = [False, True] +SINGLE_THREAD_OPTS = [False] +DTYPE_OPTS = ['float32'] +LINE_VALS = [ + encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype) + for single_thread in SINGLE_THREAD_OPTS + for blocked_a, transposed_a in BLOCK_TRANSPOSE_A_OPTS + for blocked_b, transposed_b in BLOCK_TRANSPOSE_B_OPTS + for prepack in PREPACK_OPTS + for dtype in DTYPE_OPTS + if blocked_a or blocked_b or not prepack +] + [encode_torch_provider(single_thread, dtype) for dtype in DTYPE_OPTS for single_thread in SINGLE_THREAD_OPTS] +LINE_NAMES = LINE_VALS +LINE_STYLES = None + +default_num_threads = torch.get_num_threads() + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[128 * i for i in range(2, 21)], # Different possible values for `x_name` + line_arg='provider', # Argument name whose value corresponds to a different line in the plot. + line_vals=LINE_VALS, # Possible values for `line_arg`. + line_names=LINE_NAMES, # Label name for the lines. + styles=LINE_STYLES, # Line styles. + ylabel='GFLOPS', # Label name for the y-axis. + plot_name= + # Name for the plot. Used also as a file name for saving the plot. + f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', + args={}, # Values for function arguments not in `x_names` and `y_name`. 
+ )) +def benchmark(M, N, K, provider): + + device = 'cpu' if 'cpu' in provider else 'cuda' + backend, blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype = decode_provider(provider) + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((K, N), device=device, dtype=dtype) + + if single_thread: + torch.set_num_threads(1) + else: + torch.set_num_threads(default_num_threads) + + if backend == 'triton-cpu': + c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + a_tmp = torch.zeros((M * K + (M // BLOCK_SIZE_M) * (K // BLOCK_SIZE_K) * 64), device=device, dtype=dtype) + b_tmp = torch.zeros((K * N + (K // BLOCK_SIZE_K) * (N // BLOCK_SIZE_N) * 64), device=device, dtype=dtype) + c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + if prepack and (blocked_a or blocked_b): + grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), ) + block_transpose_combined_kernel[grid]( + a, a_tmp, b, b_tmp, # + M, N, K, # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # + GROUP_SIZE_M=GROUP_SIZE_M, # + BLOCKED_A=blocked_a, TRANSPOSED_BLOCK_A=transposed_a, # + BLOCKED_B=blocked_b, TRANSPOSED_B=transposed_b) + if blocked_a: + a = a_tmp + if blocked_b: + b = b_tmp + else: + c = torch.zeros((M, N), device=a.device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + if backend == 'torch-cpu-native': + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b, out=c), quantiles=quantiles) + elif backend == 'torch-cpu-compile': + compiled = torch.compile(torch.matmul) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles) + elif backend == 'triton-cpu': + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul_fma(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, + transposed_b, num_threads=int(single_thread)), quantiles=quantiles, + measure_time_with_hooks=True, rep=1000) + perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +# %% +# We can now run the decorated function above. 
Pass `print_data=True` to see the performance number, `show_plots=True` to plot them, and/or +# `save_path='/path/to/results/' to save them to disk along with raw CSV data: +benchmark.run(print_data=True, show_plots=True) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 3670ee298e15..ad26f6d37157 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -175,6 +175,8 @@ def make_tttcir(self, mod, metadata, opt): amx_fp16 = False amx_bf16 = 'amx-bf16' in self.cpu_features cpu.passes.ttcpuir.add_convert_dot_to_amx(pm, amx_int8, amx_fp16, amx_bf16) + if 'avx512f' in self.cpu_features: + cpu.passes.ttcpuir.add_convert_dot_to_fma(pm) cpu.passes.ttcpuir.add_convert_dot_generic(pm) promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features # We don't have any lowering for mixed precision matmuls, so always use casts for now diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index 7d3ecaf5c515..f0c7a777e5fa 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -36,6 +36,7 @@ createConvertDotProduct(bool useHorizontalSum); std::unique_ptr> createConvertDotToAMX(); std::unique_ptr> createConvertDotToAMX(bool convertInt8, bool convertFp16, bool convertBf16); +std::unique_ptr> createConvertDotToFMA(); std::unique_ptr> createConvertDotGeneric(); std::unique_ptr> createCanonicalize(); diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index 42df78e0bb12..00c01a4725ce 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -125,6 +125,18 @@ def ConvertDotToAMX : Pass<"triton-cpu-convert-dot-to-amx", "mlir::ModuleOp"> { "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertDotToFMA : Pass<"triton-cpu-convert-dot-to-fma", "mlir::ModuleOp"> { + let summary = "Decompose dot product op to a series of FMA operations."; + let description = [{ }]; + + let constructor = "mlir::triton::cpu::createConvertDotToFMA()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + def ConvertDotGeneric : Pass<"triton-cpu-convert-dot-generic", "mlir::ModuleOp"> { let summary = "Generic convertion of dot product op."; let description = [{ diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt index 17d2560e6187..c6e9b4ed69e6 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -1,6 +1,8 @@ add_triton_library(TritonCPUTransforms + ConvertDotOp/ConvertDotCommon.cpp ConvertDotOp/ConvertDotGeneric.cpp ConvertDotOp/ConvertDotToAMX.cpp + ConvertDotOp/ConvertDotToFMA.cpp Canonicalize.cpp ConvertDotProduct.cpp ConvertUnsupportedOps.cpp diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp new file mode 100644 index 000000000000..4ad5de863fb4 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp @@ -0,0 +1,190 @@ +#include "ConvertDotCommon.h" + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + 
+namespace mlir { +namespace triton { +namespace cpu { + +bool isLoopCarriedAcc(Value acc) { + LDBG("Check if accumulator can be held in tiles: " << acc); + if (!acc.hasOneUse()) { + LDBG(" No. Has multiple uses."); + for (auto op : acc.getUsers()) + LDBG(" " << *op); + return false; + } + + auto blockArg = dyn_cast(acc); + if (!blockArg) { + LDBG(" No. Not a block argument."); + return false; + } + + auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); + if (!forOp) { + LDBG(" No. Not in a for-loop."); + return false; + } + + blockArg.getArgNumber(); + + Value updAcc = acc.getUsers().begin()->getResult(0); + if (!updAcc.hasOneUse()) { + LDBG(" No. Has multiple uses."); + return false; + } + + auto &updAccUse = *updAcc.getUses().begin(); + if (!isa(updAccUse.getOwner()) || + updAccUse.getOperandNumber() != + (blockArg.getArgNumber() - forOp.getNumInductionVars())) { + LDBG(" No. Loop carried dependency not detected."); + return false; + } + + LDBG(" Yes."); + return true; +} + +Value getInitAccValue(Value val) { + auto blockArg = cast(val); + auto forOp = cast(blockArg.getOwner()->getParentOp()); + int initValIdx = blockArg.getArgNumber() - forOp.getNumInductionVars(); + return forOp.getInitArgs()[initValIdx]; +} + +MemBuffer findInputBuffer(Value val, bool allowTransposed) { + MemBuffer buf; + + if (allowTransposed) { + auto transposeOp = val.getDefiningOp(); + if (transposeOp) { + val = transposeOp.getVector(); + buf.transposed = true; + } + } + + auto valLoad = val.getDefiningOp(); + if (!valLoad || hasMaskOrBoundsCheck(valLoad)) { + LDBG("Couldn't find a buffer with input: " << val); + return buf; + } + + buf.memRef = valLoad.getSource(); + buf.indices = valLoad.getIndices(); + LLVM_DEBUG( + DBGS() << "Found buffer with input: " << val << "\n"; + DBGS() << " MemRef: " << buf.memRef << "\n"; DBGS() << " Indices: "; + llvm::interleaveComma(buf.indices, llvm::dbgs()); llvm::dbgs() << "\n"); + + auto forOp = dyn_cast(valLoad->getParentOp()); + if (!forOp) { + LDBG(" Skip steps. Not in a for-loop."); + return buf; + } + + auto extractMemRef = buf.memRef.getDefiningOp(); + if (!extractMemRef) { + LDBG(" Skip steps. No ExtractMemRefOp."); + return buf; + } + + ExtractIndicesOp extractIndices; + for (auto index : buf.indices) { + auto def = index.getDefiningOp(); + if (!def || (extractIndices && def != extractIndices)) { + LDBG(" Skip steps. No ExtractIndicesOp."); + return buf; + } + extractIndices = def; + } + + if (extractMemRef.getSrc() != extractIndices.getSrc()) { + LDBG(" Skip steps. Mismatched ExtractMemRefOp and ExtractIndicesOp."); + return buf; + } + + BlockArgument blockPtrArg = dyn_cast(extractMemRef.getSrc()); + if (!blockPtrArg) { + LDBG(" Skip steps. No block pointer arg."); + return buf; + } + + OpOperand *yieldOp = forOp.getTiedLoopYieldedValue(blockPtrArg); + if (!yieldOp) { + LDBG(" Skip steps. No block pointer in yield."); + return buf; + } + + auto advance = yieldOp->get().getDefiningOp(); + if (!advance) { + LDBG(" Skip steps. No AdvanceOp."); + return buf; + } + + if (advance.getPtr() != blockPtrArg) { + LDBG(" Skip steps. 
AdvanceOp doesn't use block pointer arg."); + return buf; + } + + buf.step = advance.getOffsets(); + LLVM_DEBUG(DBGS() << " Step: "; + llvm::interleaveComma(buf.step, llvm::dbgs()); + llvm::dbgs() << "\n"); + + return buf; +} + +Value maybeCast(Location loc, Value val, Type dstElemTy, + PatternRewriter &rewriter) { + VectorType srcTy = cast(val.getType()); + if (srcTy.getElementType() == dstElemTy) + return val; + + VectorType dstTy = srcTy.cloneWith(std::nullopt, dstElemTy); + if (srcTy.getElementType().isInteger()) { + if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) + return rewriter.create(loc, dstTy, val); + return rewriter.create(loc, dstTy, val); + } + + if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) + return rewriter.create(loc, dstTy, val); + return rewriter.create(loc, dstTy, val); +} + +MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, + Operation *allocaPoint, PatternRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(allocaPoint); + auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); + Value memRef = rewriter.create( + loc, memRefTy, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); + Value zeroIdx = rewriter.create(loc, 0); + SmallVector indices(2, zeroIdx); + return {memRef, indices}; +} + +Value shiftIndex(Location loc, Value index, int64_t offs, + PatternRewriter &rewriter) { + if (!offs) + return index; + + // Do constant folding right away here for better code readability + // after the pass. + auto cstOp = dyn_cast(index.getDefiningOp()); + if (cstOp) { + int64_t oldVal = cast(cstOp.getValue()).getInt(); + return rewriter.create(loc, oldVal + offs); + } + + Value offsVal = rewriter.create(loc, offs); + return rewriter.create(loc, index.getType(), index, offsVal); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h new file mode 100644 index 000000000000..e26529d91882 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h @@ -0,0 +1,72 @@ +#include "cpu/include/TritonCPUTransforms/OptCommon.h" + +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#define DEBUG_TYPE "triton-cpu-dot-conversion" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace mlir { +namespace triton { +namespace cpu { + +// This structure describes input/output buffer. +struct MemBuffer { + Value memRef; + SmallVector indices; + // If buffer is accessed in a loop and indices are advanced + // on each iteration, then step can hold those index offsets. + // Empty step doesn't mean indices are loop invariant. + SmallVector step; + // True if buffer holds transposed value. + bool transposed = false; + + bool empty() const { return !memRef; } +}; + +// Check if accumulator value is updated in a loop and has no other +// usages than a dot op that updates it. Loads, stores, and casts +// for such accumulator can be done outside of the loop. +bool isLoopCarriedAcc(Value acc); + +// Get initial value for a loop-carried accumulator. +Value getInitAccValue(Value val); + +// Check if vector transfer read/write operation uses a mask +// or involves a bounds check. 
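+// Transfers that fail this check are not treated as reusable input buffers by
+// the dot conversions; such values are spilled to a temporary buffer instead
+// (see findInputBuffer and its callers).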
+template bool hasMaskOrBoundsCheck(T op) { + auto inBounds = op.getInBounds(); + Value mask = op.getMask(); + bool hasBoundsCheck = + std::any_of(inBounds.begin(), inBounds.end(), [](Attribute attr) { + return !cast(attr).getValue(); + }); + return hasBoundsCheck || mask; +} + +// Search for a buffer holding required value. If allowTransposed is true, +// then buffer is allowed to hold both transposed and not transposed value. +// Return empty buffer if no memory holding value was found. +MemBuffer findInputBuffer(Value val, bool allowTransposed = false); + +// Cast vector to a specified element type using ext or trunc +// operations. Return the original value if it already matches +// the required element type. +Value maybeCast(Location loc, Value val, Type dstElemTy, + PatternRewriter &rewriter); + +// Allocate temporary buffer on stack for specified vector type. +MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, + Operation *allocaPoint, PatternRewriter &rewriter); + +// Move index by specified offset. Do constannt folding if possible. +Value shiftIndex(Location loc, Value index, int64_t offs, + PatternRewriter &rewriter); + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp new file mode 100644 index 000000000000..4d1832ca8cf9 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp @@ -0,0 +1,462 @@ +#include "ConvertDotCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTTOFMA +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// This structure is used to hold candidates for conversion to FMA operations. +struct FmaDotOpCandidate { + // Operation to convert. + cpu::DotOp op; + // Here we keep actual element types used by LHS, RHS, and accumulator for + // computation. + Type lhsElemTy; + Type rhsElemTy; + Type accElemTy; + // Accumulator size. + int64_t accVecSize; + int64_t accRows; + // If accumulator is updated in a loop, then this flag indicates if we + // should keep it in registers the whole loop. + bool keepAccOnRegs = false; + // Memory buffer holding LHS. Can be empty if LHS is not a result of a + // simple load. + MemBuffer lhsBuf; + // Memory buffer holding RHS. Can be empty if RHS is not a result of a + // simple load. + MemBuffer rhsBuf; +}; + +// Check if input and output types can be handled by FMA (possibly, using +// additional casts for input/output). Returns true if FMA lowering is possible. +// In this case, element type fields of the candidate structure are filled +// with actual types to be used in lowering. +bool checkElemTypes(Type lhsElemTy, Type rhsElemTy, Type accElemTy, + Type resElemTy, FmaDotOpCandidate &candidate) { + MLIRContext *ctx = lhsElemTy.getContext(); + if (lhsElemTy.isInteger() || rhsElemTy.isInteger() || resElemTy.isInteger()) { + LDBG("Drop candidate because int types are not supported."); + return false; + } + + // Find a type to use for computations. 
Here we assume FMA works on FP32 + // and FP64, so smaller types are promoted. Flags should be added to cover + // other cases. + Type commonInputElemTy; + if (lhsElemTy.isF64() || rhsElemTy.isF64() || resElemTy.isF64()) + commonInputElemTy = Float64Type::get(ctx); + else + commonInputElemTy = Float32Type::get(ctx); + + candidate.lhsElemTy = commonInputElemTy; + candidate.rhsElemTy = commonInputElemTy; + candidate.accElemTy = commonInputElemTy; + + return true; +} + +// Check input shapes. Currently, support only 2D cases and ignore small +// inputs. +bool checkInputShapes(VectorType lhsTy, VectorType resTy) { + if (lhsTy.getRank() != 2) + return false; + + if (resTy.getDimSize(1) < 8) + return false; + + return true; +} + +// Check if specified ContractionOp can be lowered to FMA operations. +// If conversion is possible, then true is returned and candidate +// structure is filled with detailed transformation info. +bool isFmaCandidate(cpu::DotOp op, FmaDotOpCandidate &candidate) { + MLIRContext *ctx = op.getContext(); + VectorType lhsTy = op.getA().getType(); + VectorType rhsTy = op.getB().getType(); + VectorType accTy = op.getC().getType(); + VectorType resTy = op.getType(); + + LDBG("Considering candidate op: " << op); + + // Check if input and output types match available hardware capabilities. + // If check is successful then effective element types are assigned to the + // candidate. + if (!checkElemTypes(lhsTy.getElementType(), rhsTy.getElementType(), + accTy.getElementType(), resTy.getElementType(), + candidate)) + return false; + + // Check input shapes. + if (!checkInputShapes(lhsTy, resTy)) + return false; + + candidate.op = op; + candidate.accVecSize = resTy.getDimSize(1); + candidate.accRows = resTy.getDimSize(0); + candidate.keepAccOnRegs = isLoopCarriedAcc(op.getC()); + + if (lhsTy.getElementType() == candidate.lhsElemTy) + candidate.lhsBuf = findInputBuffer(op.getA(), true); + if (rhsTy.getElementType() == candidate.rhsElemTy) + candidate.rhsBuf = findInputBuffer(op.getB(), false); + + return true; +} + +MemBuffer storeToTmpBuffer(Location loc, Value val, Operation *allocaPoint, + PatternRewriter &rewriter) { + LDBG("Storing vector to a temporary buffer: " << val); + auto vecTy = cast(val.getType()); + MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + rewriter.create(loc, val, buf.memRef, buf.indices); + return buf; +} + +SmallVector shiftIndices(Location loc, ArrayRef indices, + bool transposed, int64_t m, int64_t n, + PatternRewriter &rewriter) { + SmallVector res(indices.begin(), indices.end() - 2); + if (transposed) + std::swap(m, n); + res.push_back(shiftIndex(loc, *(indices.end() - 2), m, rewriter)); + res.push_back(shiftIndex(loc, *(indices.end() - 1), n, rewriter)); + return res; +} + +SmallVector shiftIndices(Location loc, const MemBuffer &buf, int64_t m, + int64_t n, PatternRewriter &rewriter) { + return shiftIndices(loc, buf.indices, buf.transposed, m, n, rewriter); +} + +Value loadRow(Location loc, VectorType resTy, const MemBuffer &buf, int64_t m, + PatternRewriter &rewriter) { + assert(!buf.empty()); + SmallVector indices = buf.indices; + indices[indices.size() - 2] = + shiftIndex(loc, indices[indices.size() - 2], m, rewriter); + return rewriter.create(loc, resTy, buf.memRef, indices); +} + +void storeRow(Location loc, const MemBuffer &buf, int64_t rowIdx, Value vec, + PatternRewriter &rewriter) { + SmallVector indices = buf.indices; + indices[indices.size() - 2] = + shiftIndex(loc, buf.indices[indices.size() - 2], rowIdx, rewriter); + 
rewriter.create(loc, vec, buf.memRef, indices); +} + +void storeRows(Location loc, const MemBuffer &buf, + const SmallVector &vecs, PatternRewriter &rewriter) { + SmallVector indices = buf.indices; + for (int64_t m = 0; m < vecs.size(); ++m) + storeRow(loc, buf, m, vecs[m], rewriter); +} + +SmallVector extractRows(Location loc, Value vec, + PatternRewriter &rewriter) { + VectorType vecTy = cast(vec.getType()); + SmallVector res; + for (int64_t m = 0; m < vecTy.getDimSize(0); ++m) { + auto row = + rewriter.create(loc, vec, SmallVector({m})); + res.push_back(row); + } + return res; +} + +Value mergeRows(Location loc, VectorType resTy, const SmallVector &tiles, + PatternRewriter &rewriter) { + Value res = + rewriter.create(loc, rewriter.getZeroAttr(resTy)); + for (int64_t m = 0; m < tiles.size(); ++m) + res = rewriter.create(loc, tiles[m], res, + SmallVector({m})); + return res; +} + +Value broadcastElem(Location loc, VectorType tileTy, const MemBuffer &buf, + int64_t m, int64_t n, PatternRewriter &rewriter) { + SmallVector indices = shiftIndices(loc, buf, m, n, rewriter); + Value scalar = rewriter.create(loc, buf.memRef, indices); + return rewriter.create(loc, tileTy, scalar); +} + +SmallVector computePrefetchIndices(Location loc, const MemBuffer &buf, + int64_t iters, + PatternRewriter &rewriter) { + SmallVector scaledStep; + Value itersVal; + for (auto step : buf.step) { + if (iters == 1) + scaledStep.push_back(rewriter.create( + loc, rewriter.getIndexType(), step)); + else if (auto cstOp = dyn_cast(step.getDefiningOp())) { + int64_t oldVal = cast(cstOp.getValue()).getInt(); + scaledStep.push_back( + rewriter.create(loc, oldVal * iters)); + } else { + if (!itersVal) + itersVal = + rewriter.create(loc, iters, step.getType()); + scaledStep.push_back(rewriter.create( + loc, rewriter.getIndexType(), + rewriter.create(loc, step.getType(), step, itersVal))); + } + } + + SmallVector res; + for (int64_t i = 0; i < scaledStep.size(); ++i) + res.push_back(rewriter.create( + loc, buf.indices[i].getType(), buf.indices[i], scaledStep[i])); + return res; +} + +void prefetch(Location loc, const MemBuffer &buf, int64_t m, int64_t n, + ArrayRef prefetchIndices, int64_t hint, + PatternRewriter &rewriter) { + SmallVector indices = + shiftIndices(loc, prefetchIndices, buf.transposed, m, n, rewriter); + rewriter.create(loc, buf.memRef, indices, false, hint, + true); +} + +LogicalResult convertCandidate(FmaDotOpCandidate &candidate, + PatternRewriter &rewriter) { + cpu::DotOp op = candidate.op; + Location loc = op.getLoc(); + VectorType lhsTy = cast(op.getA().getType()); + VectorType rhsTy = cast(op.getB().getType()); + VectorType accTy = cast(op.getC().getType()); + VectorType resTy = cast(op.getResult().getType()); + VectorType rhsVecTy = + VectorType::get(candidate.accVecSize, candidate.rhsElemTy); + VectorType accVecTy = + VectorType::get(candidate.accVecSize, candidate.accElemTy); + + Operation *allocaPoint = op; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + // Cast input data if required and prepare input buffer. It might be temporary + // buffers with stored vectors or the original input memory. 
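+  // (The FMA loop below reads the RHS row by row and broadcasts individual LHS
+  // elements, so both operands must be addressable in memory; operands that are
+  // not plain in-bounds loads are first spilled to a stack buffer via
+  // storeToTmpBuffer.)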
+ MemBuffer lhsBuf = candidate.lhsBuf; + if (lhsBuf.empty()) { + Value lhs = maybeCast(loc, op.getA(), candidate.lhsElemTy, rewriter); + lhsBuf = storeToTmpBuffer(loc, lhs, allocaPoint, rewriter); + } + + MemBuffer rhsBuf = candidate.rhsBuf; + if (rhsBuf.empty()) { + Value rhs = maybeCast(loc, op.getB(), candidate.rhsElemTy, rewriter); + rhsBuf = storeToTmpBuffer(loc, rhs, allocaPoint, rewriter); + } + + Value acc = maybeCast(loc, op.getC(), candidate.accElemTy, rewriter); + Value accToStore = acc; + scf::ForOp forOp; + if (candidate.keepAccOnRegs) { + forOp = cast(op->getParentOp()); + accToStore = getInitAccValue(acc); + } + + SmallVector accVecs; + SmallVector accInitVecs; + if (candidate.keepAccOnRegs) { + // Initial tile values are loaded before the loop and then directly + // used within the loop. Later, new iter values will be added to + // add loop carried-dependencies for accumulator tiles and accInitTiles + // will be used as initializers for them. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + LDBG("Loading accumulator to tiles before the loop."); + accInitVecs = extractRows(loc, accToStore, rewriter); + accVecs = accInitVecs; + } else { + accVecs = extractRows(loc, acc, rewriter); + } + + // Compute indices to be used by prefetch. + int64_t lhsPrefetchIters = + std::max(int64_t(128) / lhsTy.getNumElements(), int64_t(1)); + auto lhsPrefetchIndices = + computePrefetchIndices(loc, candidate.lhsBuf, lhsPrefetchIters, rewriter); + int64_t rhsPrefetchIters = + std::max(int64_t(128) / rhsTy.getNumElements(), int64_t(1)); + auto rhsPrefetchIndices = + computePrefetchIndices(loc, candidate.rhsBuf, rhsPrefetchIters, rewriter); + Value nextRhsVec = loadRow(loc, rhsVecTy, rhsBuf, 0, rewriter); + for (int64_t k = 0; k < lhsTy.getDimSize(1); ++k) { + Value rhsVec = nextRhsVec; + + // Load next vector in advance to hide load latency. + if (k != lhsTy.getDimSize(1) - 1) + nextRhsVec = loadRow(loc, rhsVecTy, rhsBuf, k + 1, rewriter); + + // Prefetch RHS to LLC cache. + if (!rhsPrefetchIndices.empty()) + prefetch(loc, candidate.rhsBuf, k, 0, rhsPrefetchIndices, 1, rewriter); + + Value nextLhsBroadcasted = + broadcastElem(loc, accVecTy, lhsBuf, 0, k, rewriter); + for (int64_t m = 0; m < candidate.accRows; ++m) { + Value lhsBroadcasted = nextLhsBroadcasted; + + // Load next value in advance to hide load latency. + if (m != candidate.accRows - 1) + nextLhsBroadcasted = + broadcastElem(loc, accVecTy, lhsBuf, m + 1, k, rewriter); + + // Prefetch LHS to L1 cache. + if (!lhsPrefetchIndices.empty()) { + if ((candidate.lhsBuf.transposed && (m % 8 == 0)) || + (!candidate.lhsBuf.transposed && (k % 8 == 0))) + prefetch(loc, candidate.lhsBuf, m, k, lhsPrefetchIndices, 3, + rewriter); + } + + accVecs[m] = rewriter.create(loc, rhsVec, lhsBroadcasted, + accVecs[m]); + } + } + + if (candidate.keepAccOnRegs) { + // In this case we have the whole accumulator/result on tiles. Loop + // carried dependencies are not in place yet and should be added. + // After the loop, resulting tiles should either be stored to the + // output buffer, or moved to a vector through a temporary buffer. + + // We don't need the original accumulator and contraction op anymore. + // Directly yield orig accumulator value, so it would be later removed + // as unused. The original contraction can be removed right away. 
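+    // isLoopCarriedAcc guarantees the dot result has exactly one use: the
+    // loop's yield operand that carries the accumulator. Its operand number
+    // identifies which result of the rewritten loop must be replaced with the
+    // merged accumulator below.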
+ int64_t origResIdx = op.getResult().getUses().begin()->getOperandNumber(); + rewriter.replaceOp(op, op.getC()); + + // Now, replace the loop with a new one to add loop carried dependency for + // accumulator tiles. + LDBG("Rewrite loop to introduce loop carried dependencies for accumulator " + "tiles."); + SmallVector newInitOperands; + SmallVector newYieldedValues; + for (int64_t m = 0; m < candidate.accRows; ++m) { + LDBG("Initial value\n " << accInitVecs[m] << "\nis combined with\n " + << accVecs[m]); + newInitOperands.push_back(accInitVecs[m]); + newYieldedValues.push_back(accVecs[m]); + } + auto newForOp = cast(*forOp.replaceWithAdditionalYields( + rewriter, newInitOperands, true, + [&newYieldedValues](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return newYieldedValues; + })); + + // The resulting tiles are now in the new loop results. + auto resVecs = newForOp.getResults().take_back(newYieldedValues.size()); + for (int64_t m = 0; m < candidate.accRows; ++m) + accVecs[m] = resVecs[m]; + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(newForOp); + // Collect all results into a single vector. + LDBG("Merging resulting rows to replace loop result."); + VectorType resTy = accTy.cloneWith(std::nullopt, candidate.accElemTy); + Value newVal = mergeRows(loc, resTy, accVecs, rewriter); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceAllUsesWith(newForOp.getResult(origResIdx), newVal); + } else { + // The result is in the buffer. We should load it and replace the original + // constraction result. + LDBG("Merging resulting rows to replace orig op result."); + VectorType resTy = accTy.cloneWith(std::nullopt, candidate.accElemTy); + Value newVal = mergeRows(loc, resTy, accVecs, rewriter); + // We might need to cast back to the original type. 
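The loop rewrite above exists so that accumulator rows stay in registers across the reduction loop: they become loop-carried iter args (with the pre-loop values as initializers) instead of being stored to and reloaded from a buffer on every iteration. A NumPy sketch of the two variants (illustrative names only, not part of the patch):

    import numpy as np

    # Accumulator round-trips through memory on every iteration.
    def dot_acc_in_memory(a_blocks, b_blocks, acc_buf):
        for a, b in zip(a_blocks, b_blocks):
            acc = acc_buf.copy()      # reload accumulator
            acc += a @ b
            acc_buf[...] = acc        # store it back
        return acc_buf

    # Accumulator is carried by the loop itself and yielded each iteration.
    def dot_acc_carried(a_blocks, b_blocks, acc):
        for a, b in zip(a_blocks, b_blocks):
            acc = acc + a @ b         # acc is an iter arg of the loop
        return acc                    # extra loop result replaces the old use

    a_blocks = [np.random.rand(8, 8).astype(np.float32) for _ in range(4)]
    b_blocks = [np.random.rand(8, 32).astype(np.float32) for _ in range(4)]
    ref = dot_acc_in_memory(a_blocks, b_blocks, np.zeros((8, 32), np.float32))
    out = dot_acc_carried(a_blocks, b_blocks, np.zeros((8, 32), np.float32))
    assert np.allclose(out, ref)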
+ newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceOp(op, newVal); + } + + return success(); +} + +struct ConvertDotToFMA + : public triton::cpu::impl::ConvertDotToFMABase { + ConvertDotToFMA() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + SmallVector candidates; + mod->walk([this, &candidates](cpu::DotOp op) { + FmaDotOpCandidate candidate; + if (isFmaCandidate(op, candidate)) { + LLVM_DEBUG({ + LDBG("Found FMA candidate"); + LDBG(" Op: " << candidate.op); + LDBG(" LhsElemTy: " << candidate.lhsElemTy); + LDBG(" RhsElemTy: " << candidate.rhsElemTy); + LDBG(" AccElemTy: " << candidate.accElemTy); + LDBG(" AccVecSize: " << candidate.accVecSize); + LDBG(" AccRows: " << candidate.accRows); + LDBG(" KeepAccOnRegs: " << candidate.keepAccOnRegs); + if (!candidate.lhsBuf.empty()) { + LDBG(" LhsBuf: " << candidate.lhsBuf.memRef); + LDBG(" Transposed: " << candidate.lhsBuf.transposed); + } + if (!candidate.rhsBuf.empty()) { + LDBG(" RhsBuf: " << candidate.rhsBuf.memRef); + LDBG(" Transposed: " << candidate.rhsBuf.transposed); + } + }); + candidates.push_back(candidate); + } + return WalkResult::advance(); + }); + + for (auto &candidate : candidates) { + LDBG("Starting conversion of candidate: " << candidate.op); + PatternRewriter rewriter(context); + rewriter.setInsertionPoint(candidate.op); + if (succeeded(convertCandidate(candidate, rewriter))) { + LDBG("Conversion succeeded!"); + } else { + LDBG("Conversion failed!"); + } + } + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertDotToFMA() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 8206b021f357..50159cd94d33 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -91,6 +91,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { pm.addPass(mlir::triton::cpu::createConvertDotToAMX( convertInt8, convertFp16, convertBf16)); }); + m.def("add_convert_dot_to_fma", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createConvertDotToFMA()); + }); m.def("add_convert_dot_generic", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertDotGeneric()); }); From ee1bdc9ba9d8df7e8f5625cbc32ceeae82ab61c2 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Thu, 12 Dec 2024 16:31:18 -0600 Subject: [PATCH 154/165] AMX lowering improvements (#194) * Improve AMX lowering to minimize loads and stores. Signed-off-by: Ilya Enkovich * Support bfloat16 in CPU matmul tutorials. Signed-off-by: Ilya Enkovich --------- Signed-off-by: Ilya Enkovich --- .../tutorials/03-matrix-multiplication-cpu.py | 4 +- python/tutorials/cpu-blocked-matmul-fp32.py | 58 +-- .../include/TritonCPUTransforms/OptCommon.h | 11 + .../ConvertDotOp/ConvertDotToAMX.cpp | 430 +++++++++--------- 4 files changed, 261 insertions(+), 242 deletions(-) diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index f2ee03dfadc2..3b44a30bf7ad 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -156,7 +156,7 @@ import os DTYPE = getattr(torch, (os.getenv("DTYPE", "float32"))) -# Chosse block size depending on dtype. We have more register +# Choose block size depending on dtype. 
We have more register # capacity for bfloat16/float16 compared to float32. BLOCK_SIZE_M = 8 if DTYPE == torch.float32 else 32 BLOCK_SIZE_N = 32 @@ -306,7 +306,7 @@ def matmul_preprocess_input(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, n K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" if c is None: # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=a.dtype) + c = torch.empty((M, N), device=a.device, dtype=torch.float32) else: assert c.shape == (M, N), "Incompatible dimensions" diff --git a/python/tutorials/cpu-blocked-matmul-fp32.py b/python/tutorials/cpu-blocked-matmul-fp32.py index 0df8f41b9b2b..8f0f0ebce41a 100644 --- a/python/tutorials/cpu-blocked-matmul-fp32.py +++ b/python/tutorials/cpu-blocked-matmul-fp32.py @@ -15,10 +15,14 @@ import triton import triton.language as tl +import os -BLOCK_SIZE_M = 8 +DTYPE = os.getenv("DTYPE", "float32") +# Choose block size depending on dtype. We have more register +# capacity for bfloat16/float16 compared to float32. +BLOCK_SIZE_M = 8 if DTYPE == "float32" else 32 BLOCK_SIZE_N = 32 -BLOCK_SIZE_K = 8 +BLOCK_SIZE_K = 8 if DTYPE == "float32" else 32 GROUP_SIZE_M = 8 @@ -39,13 +43,10 @@ # used by Triton CPU backend which processes RHS block row-by-row and LHS # block column-by-column. @triton.jit -def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, - TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, +def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): - tl.static_assert(M % BLOCK_SIZE_M == 0) - tl.static_assert(N % BLOCK_SIZE_N == 0) tl.static_assert(BLOCKED_A or not TRANSPOSED_BLOCK_A) tl.static_assert(BLOCKED_B or not TRANSPOSED_B) pid = tl.program_id(axis=0) @@ -62,10 +63,10 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N a_out_block_m = in_block_m A_OUT_BLOCK_SIZE_M: tl.constexpr = BLOCK_SIZE_K if TRANSPOSED_BLOCK_A else BLOCK_SIZE_M A_OUT_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_M if TRANSPOSED_BLOCK_A else BLOCK_SIZE_K - A_OUT_BLOCKS_M: tl.constexpr = M // BLOCK_SIZE_M - A_OUT_BLOCKS_K: tl.constexpr = K // BLOCK_SIZE_K + A_OUT_BLOCKS_M = M // BLOCK_SIZE_M + A_OUT_BLOCKS_K = K // BLOCK_SIZE_K A_OUT_STRIDE_M: tl.constexpr = A_OUT_BLOCK_SIZE_K - A_OUT_STRIDE_BLOCK_M: tl.constexpr = BLOCK_SIZE_M * K + A_OUT_STRIDE_BLOCK_M = BLOCK_SIZE_M * K A_OUT_STRIDE_BLOCK_K: tl.constexpr = BLOCK_SIZE_M * BLOCK_SIZE_K for in_block_k in tl.range(in_block_n, A_OUT_BLOCKS_K, N // BLOCK_SIZE_N): a_out_block_k = in_block_k @@ -84,10 +85,10 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N tl.store(a_out_ptr, val) if BLOCKED_B: - B_OUT_BLOCKS_K: tl.constexpr = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K - B_OUT_BLOCKS_N: tl.constexpr = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N + B_OUT_BLOCKS_K = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K + B_OUT_BLOCKS_N = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N B_OUT_STRIDE_K: tl.constexpr = BLOCK_SIZE_N - B_OUT_STRIDE_BLOCK_K: tl.constexpr = (K * BLOCK_SIZE_N if TRANSPOSED_B 
else BLOCK_SIZE_K * N) + B_OUT_STRIDE_BLOCK_K = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N) B_OUT_STRIDE_BLOCK_N: tl.constexpr = BLOCK_SIZE_K * BLOCK_SIZE_N for in_block_k in tl.range(in_block_m, K // BLOCK_SIZE_K, M // BLOCK_SIZE_M): b_out_block_k = in_block_n if TRANSPOSED_B else in_block_k @@ -120,11 +121,11 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M: tl.constexpr, N # Reshape is used to remove the heading (1, 1) dimensions, but CPU backend folds it with the load # operation and it doesn't prevent direct vector loads from the input memory. @triton.jit -def matmul_kernel_fma(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - # number of blocks in a group - GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, - BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): +def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + # number of blocks in a group + GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, + BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): # TRANSPOSED_BLOCK_A means that each block in A is transposed. # It is allowed only for blocked input. assert (BLOCKED_A or not TRANSPOSED_BLOCK_A) @@ -188,8 +189,8 @@ def matmul_kernel_fma(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, tl.store(c_block_ptr, c) -def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, - PREPACKED, BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0): +def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, PREPACKED, + BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, num_threads=0): #TODO: Currently masked load is not supported yet. 
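The 4D block pointers used by the kernel above assume A and B have been rearranged so that every (BLOCK_SIZE_M, BLOCK_SIZE_K) or (BLOCK_SIZE_K, BLOCK_SIZE_N) tile is contiguous in memory. A NumPy sketch of that blocked layout (illustrative only; the tutorial itself builds it with block_transpose_combined_kernel):

    import numpy as np

    # A (M, K) matrix becomes (M // BM, K // BK, BM, BK) with each block contiguous.
    def to_blocked(x, BM, BK):
        M, K = x.shape
        return (x.reshape(M // BM, BM, K // BK, BK)
                 .transpose(0, 2, 1, 3)
                 .copy())

    a = np.arange(32 * 64, dtype=np.float32).reshape(32, 64)
    ab = to_blocked(a, 8, 8)
    # Block (1, 2) of the blocked tensor equals the corresponding tile of `a`.
    assert np.array_equal(ab[1, 2], a[8:16, 16:24])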
assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" @@ -207,7 +208,7 @@ def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tens a = ab if BLOCKED_B: b = bb - matmul_kernel_fma[grid]( + matmul_kernel[grid]( a, b, c, # M, N, K, # BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # @@ -233,14 +234,14 @@ def matmul_fma(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tens rtol = 0 a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=torch.float32) b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=torch.float32) -triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU and TorchCPU match") else: print("❌ TritonCPU and TorchCPU differ, the maximum difference is " f'{torch.max(torch.abs(triton_output - torch_output))}') assert False -triton_output = matmul_fma(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU pre-packed and TorchCPU match") else: @@ -289,7 +290,7 @@ def decode_provider(provider): BLOCK_TRANSPOSE_B_OPTS = [(True, True), (False, False)] PREPACK_OPTS = [False, True] SINGLE_THREAD_OPTS = [False] -DTYPE_OPTS = ['float32'] +DTYPE_OPTS = [DTYPE] LINE_VALS = [ encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype) for single_thread in SINGLE_THREAD_OPTS @@ -316,7 +317,7 @@ def decode_provider(provider): ylabel='GFLOPS', # Label name for the y-axis. plot_name= # Name for the plot. Used also as a file name for saving the plot. - f'matmul-performance-fp32 (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', + f'matmul-performance-{DTYPE} (BLOCK_SIZE_M={BLOCK_SIZE_M}, BLOCK_SIZE_N={BLOCK_SIZE_N}, BLOCK_SIZE_K={BLOCK_SIZE_K}, GROUP_SIZE_M={GROUP_SIZE_M})', args={}, # Values for function arguments not in `x_names` and `y_name`. 
)) def benchmark(M, N, K, provider): @@ -360,9 +361,8 @@ def benchmark(M, N, K, provider): ms, min_ms, max_ms = triton.testing.do_bench(lambda: compiled(a, b, out=c), quantiles=quantiles) elif backend == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul_fma(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, - transposed_b, num_threads=int(single_thread)), quantiles=quantiles, - measure_time_with_hooks=True, rep=1000) + lambda: matmul(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, transposed_b, + num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, rep=1000) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h index a2e94f894caf..09e7ec65595d 100644 --- a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h +++ b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h @@ -125,12 +125,14 @@ inline Value shapeCast(Location loc, Value in, } // namespace mlir #define int_cst(ty, val) intCst(loc, ty, val, rewriter) +#define index_cst(val) rewriter.create(loc, val) #define cst_like(src, val) cstLike(loc, src, val, rewriter) #define op_addi(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_addf(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_subi(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_subf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_muli(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_mulf(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_bitcast(ty, val) rewriter.create(loc, ty, val) #define op_lshr(lhs, rhs) rewriter.create(loc, lhs, rhs) @@ -146,6 +148,15 @@ inline Value shapeCast(Location loc, Value in, rewriter.create(loc, cond, val, other) #define op_sitofp(ty, val) rewriter.create(loc, ty, val) #define op_fptosi(ty, val) rewriter.create(loc, ty, val) +#define op_read(ty, memRef, indices) \ + rewriter.create(loc, ty, memRef, indices) +#define op_write(val, memRef, indices) \ + rewriter.create(loc, val, memRef, indices) +#define op_interleave(lhs, rhs) \ + rewriter.create(loc, lhs, rhs) +#define op_extract(vec, idx) rewriter.create(loc, vec, idx) +#define op_store(val, mem, idx) \ + rewriter.create(loc, val, mem, idx) #define op_icmp_eq(lhs, rhs) \ rewriter.create(loc, arith::CmpIPredicate::eq, lhs, rhs) diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index 23f16944de41..1b6dd9269ac1 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -1,4 +1,4 @@ -#include "cpu/include/TritonCPUTransforms/OptCommon.h" +#include "ConvertDotCommon.h" #include "cpu/include/TritonCPUTransforms/Passes.h" @@ -24,24 +24,12 @@ namespace cpu { } // namespace triton } // namespace mlir -#define DEBUG_TYPE "triton-cpu-dot-to-amx" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - using namespace mlir; using namespace mlir::triton; using namespace mlir::triton::cpu; namespace { -// This struct describes buffers used to load/store AMX tiles. 
-struct AmxBuffer { - Value memRef; - SmallVector indices; - - bool empty() const { return !memRef; } -}; - // This structure is used to hold candidates for conversion to AMX // Mul[F|I]Op operations. struct AmxDotOpCandidate { @@ -75,7 +63,7 @@ struct AmxDotOpCandidate { // If resulting tiles are not required to be trasfered to vectors and can be // directly stored to the output memory instead, then this field holds a // buffer to use. - AmxBuffer outBuf; + MemBuffer outBuf; // If output buffer is used then keep the original vector store here. Operation *origStore = nullptr; }; @@ -182,50 +170,6 @@ bool checkInputShapes(VectorType lhsTy, VectorType resTy) { return true; } -// Check if accumulator value is updated in a loop and has no other -// usages than a dot op, that updates it. Tile loads/stores and casts -// for such accumulators can be done outside of the loop. -bool isLoopCarriedAcc(Value acc) { - LDBG("Check if accumulator can be held in tiles: " << acc); - if (!acc.hasOneUse()) { - LDBG(" No. Has multiple uses."); - for (auto op : acc.getUsers()) - LDBG(" " << *op); - return false; - } - - auto blockArg = dyn_cast(acc); - if (!blockArg) { - LDBG(" No. Not a block argument."); - return false; - } - - auto forOp = dyn_cast(blockArg.getOwner()->getParentOp()); - if (!forOp) { - LDBG(" No. Not in a for-loop."); - return false; - } - - blockArg.getArgNumber(); - - Value updAcc = acc.getUsers().begin()->getResult(0); - if (!updAcc.hasOneUse()) { - LDBG(" No. Has multiple uses."); - return false; - } - - auto &updAccUse = *updAcc.getUses().begin(); - if (!isa(updAccUse.getOwner()) || - updAccUse.getOperandNumber() != - (blockArg.getArgNumber() - forOp.getNumInductionVars())) { - LDBG(" No. Loop carried dependency not detected."); - return false; - } - - LDBG(" Yes."); - return true; -} - // Return a value that holds the resulting loop carried accumulator value. // It's one of ForOp's results. Value getResValueForLoopCarriedAcc(cpu::DotOp op) { @@ -239,11 +183,11 @@ Value getResValueForLoopCarriedAcc(cpu::DotOp op) { // by input shapes and types. Block sizes are chosen to minimize number of // tile loads/stores including tile register spills. void setupBlockAndTileSizes(ArrayRef lhsShape, - ArrayRef rhsShape, + ArrayRef resShape, AmxDotOpCandidate &candidate) { - int64_t m = lhsShape[0]; - int64_t n = rhsShape[1]; - int64_t k = rhsShape[0]; + int64_t m = resShape[0]; + int64_t n = resShape[1]; + int64_t k = lhsShape[1]; int64_t tileM = std::min(m, (int64_t)16); int64_t tileN = std::min(n, (int64_t)16); int64_t tileK = std::min( @@ -288,7 +232,7 @@ void findOutputBuffer(Value val, AmxDotOpCandidate &candidate) { if (val.hasOneUse()) { auto store = dyn_cast(*val.user_begin()); if (store && !hasMaskOrBoundsCheck(store)) - candidate.outBuf = AmxBuffer{store.getSource(), store.getIndices()}; + candidate.outBuf = MemBuffer{store.getSource(), store.getIndices()}; candidate.origStore = store; } } @@ -319,15 +263,16 @@ bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, return false; candidate.op = op; - setupBlockAndTileSizes(lhsTy.getShape(), rhsTy.getShape(), candidate); + setupBlockAndTileSizes(lhsTy.getShape(), resTy.getShape(), candidate); candidate.keepAccOnTiles = isLoopCarriedAcc(op.getC()); // Can't keep acc in a tile the whole loop right now: // https://github.com/llvm/llvm-project/issues/109481 if (candidate.keepAccOnTiles) { - // We might not have enough tiles to hold accumulator. In this case - // keep it in a bufffer. 
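The tile and block sizes chosen above are bounded by the AMX tile shape, which is at most 16 rows of 64 bytes each, so an accumulator tile holds at most 16x16 32-bit elements and the K extent of an input tile is at most 64 bytes worth of elements. A rough sketch of that rule of thumb (illustrative only; the exact limits are computed in setupBlockAndTileSizes):

    # Rough AMX-style tile size selection; not the pass's exact formula.
    def pick_tile_sizes(m, n, k, elem_bytes):
        tile_m = min(m, 16)
        tile_n = min(n, 16)
        tile_k = min(k, 64 // elem_bytes)   # e.g. 32 for bf16, 64 for int8
        return tile_m, tile_n, tile_k

    print(pick_tile_sizes(32, 64, 128, 2))  # (16, 16, 32) for bf16 inputs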
- if (candidate.tilesInBlockM * candidate.tilesInBlockN > 1) { + // We might not have enough tiles to hold the whole accumulator. If we + // have more than one block, keep it in a bufffer. + if (candidate.tilesInBlockM * candidate.tileM < resTy.getDimSize(0) || + candidate.tilesInBlockN * candidate.tileN < resTy.getDimSize(1)) { LDBG("Accumulator is too big to keep on tiles. Keep it bufferized " "insterad."); candidate.keepAccOnTiles = false; @@ -335,14 +280,6 @@ bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, } else { findOutputBuffer(getResValueForLoopCarriedAcc(op), candidate); } - - // TODO: fix LLVM bug and remove this code. - LDBG("Avoid accumulator on tiles due to LLVM bug: " - "https://github.com/llvm/llvm-project/issues/109481."); - LDBG("Keep accumulator bufferized instead."); - candidate.keepAccOnTiles = false; - candidate.keepAccInBuf = true; - candidate.outBuf = AmxBuffer{}; } else { findOutputBuffer(op.getResult(), candidate); } @@ -350,35 +287,6 @@ bool isAmxCandidate(cpu::DotOp op, bool supportInt8, bool supportFp16, return true; } -// Cast vector to a specified element type using ext or trunc -// operations. Return the original value if it already matches -// the required element type. -Value maybeCast(Location loc, Value val, Type dstElemTy, - PatternRewriter &rewriter) { - VectorType srcTy = cast(val.getType()); - if (srcTy.getElementType() == dstElemTy) - return val; - - VectorType dstTy = srcTy.cloneWith(std::nullopt, dstElemTy); - if (srcTy.getElementType().isInteger()) { - if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) - return rewriter.create(loc, dstTy, val); - return rewriter.create(loc, dstTy, val); - } - - if (srcTy.getElementTypeBitWidth() < dstTy.getElementTypeBitWidth()) - return rewriter.create(loc, dstTy, val); - return rewriter.create(loc, dstTy, val); -} - -// Get initial value for a loop-carried accumulator. -Value getInitAccValue(Value val) { - auto blockArg = cast(val); - auto forOp = cast(blockArg.getOwner()->getParentOp()); - int initValIdx = blockArg.getArgNumber() - forOp.getNumInductionVars(); - return forOp.getInitArgs()[initValIdx]; -} - template T getSwizzledRhsTileType(T origTileType) { int64_t rowsPerGroup = 32 / origTileType.getElementTypeBitWidth(); SmallVector shape({origTileType.getDimSize(0) / rowsPerGroup, @@ -386,18 +294,6 @@ template T getSwizzledRhsTileType(T origTileType) { return origTileType.cloneWith(shape, origTileType.getElementType()); } -AmxBuffer allocateTmpBuffer(Location loc, VectorType vecTy, - Operation *allocaPoint, PatternRewriter &rewriter) { - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(allocaPoint); - auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); - Value memRef = rewriter.create( - loc, memRefTy, rewriter.getIntegerAttr(rewriter.getI64Type(), 64)); - Value zeroIdx = rewriter.create(loc, 0); - SmallVector indices(2, zeroIdx); - return {memRef, indices}; -} - // In AMX, element values shoud be packed to 32-bit groups that would be // multiplied elementwise with following accumulation. It means that RHS // needs to be pre-packed. E.g. 
for the following input @@ -418,27 +314,80 @@ void interleaveAndStore(Location loc, Value val, Value buf, int64_t rowsPerGroup = 32 / valTy.getElementTypeBitWidth(); assert(rowsPerGroup == 2 || rowsPerGroup == 4); assert(valTy.getDimSize(0) % rowsPerGroup == 0); - Value zeroIdx = rewriter.create(loc, 0); + Value zeroIdx = index_cst(0); for (int64_t i = 0; i < valTy.getDimSize(0); i += rowsPerGroup) { Value row1, row2; if (rowsPerGroup == 2) { - row1 = rewriter.create(loc, val, i); - row2 = rewriter.create(loc, val, i + 1); + row1 = op_extract(val, i); + row2 = op_extract(val, i + 1); } else { - row1 = rewriter.create( - loc, rewriter.create(loc, val, i), - rewriter.create(loc, val, i + 2)); - row2 = rewriter.create( - loc, rewriter.create(loc, val, i + 1), - rewriter.create(loc, val, i + 3)); + row1 = op_interleave(op_extract(val, i), op_extract(val, i + 2)); + row2 = op_interleave(op_extract(val, i + 1), op_extract(val, i + 3)); } - Value shuffled = rewriter.create(loc, row1, row2); - Value idx = rewriter.create(loc, i / rowsPerGroup); - rewriter.create(loc, shuffled, buf, - SmallVector({idx, zeroIdx})); + Value shuffled = op_interleave(row1, row2); + Value idx = index_cst(i / rowsPerGroup); + op_store(shuffled, buf, SmallVector({idx, zeroIdx})); } } +Value loadWithPrefetch(Location loc, VectorType ty, Value memRef, + ArrayRef indices, ArrayRef step, + PatternRewriter &rewriter) { + Value res = op_read(ty, memRef, indices); + if (!step.empty()) { + SmallVector prefetchIndices; + for (int64_t i = 0; i < indices.size(); ++i) { + prefetchIndices.push_back( + op_addi(indices[i], rewriter.create( + loc, rewriter.getIndexType(), step[i]))); + } + rewriter.create(loc, memRef, prefetchIndices, false, 1, + true); + } + return res; +} + +// Copy tensor with packing using for-loop. See interleaveAndStore for more +// details. 
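copyWithInterleave below performs the VNNI packing described above directly on buffered rows: rowsPerGroup consecutive K-rows are interleaved element-wise so that each 32-bit column of the packed tile holds one K-group. An equivalent NumPy sketch of the packing (an editorial illustration, not part of the patch):

    import numpy as np

    # VNNI packing sketch: interleave `rows_per_group` consecutive K-rows so
    # each 32-bit column of the packed tile holds one K-group.
    def vnni_pack(b):
        K, N = b.shape
        rows_per_group = 4 // b.dtype.itemsize        # 2 for 16-bit, 4 for int8
        return (b.reshape(K // rows_per_group, rows_per_group, N)
                 .transpose(0, 2, 1)
                 .reshape(K // rows_per_group, N * rows_per_group)
                 .copy())

    b = np.arange(8 * 4, dtype=np.int8).reshape(8, 4)
    packed = vnni_pack(b)   # shape (2, 16): row 0 holds K-rows 0..3 interleaved
    assert packed.shape == (2, 16)
    assert list(packed[0, :4]) == [b[0, 0], b[1, 0], b[2, 0], b[3, 0]]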
+void copyWithInterleave(Location loc, VectorType srcTy, const MemBuffer &src, + const MemBuffer &dst, PatternRewriter &rewriter) { + int64_t rowsPerGroup = 32 / srcTy.getElementTypeBitWidth(); + Value lower = index_cst(0); + Value upper = index_cst(srcTy.getDimSize(0) / rowsPerGroup); + Value one = index_cst(1); + Value rowsPerGroupVal = index_cst(rowsPerGroup); + VectorType srcVecTy = + VectorType::get({srcTy.getDimSize(1)}, srcTy.getElementType()); + auto forOp = rewriter.create(loc, lower, upper, one); + Value ivVal = forOp.getInductionVar(); + rewriter.setInsertionPointToStart(forOp.getBody()); + SmallVector srcIndices = src.indices; + int64_t mDimIdx = srcIndices.size() - 2; + Value scaledM = op_muli(ivVal, rowsPerGroupVal); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], scaledM); + Value row1 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, src.step, + rewriter); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row2 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, src.step, + rewriter); + if (rowsPerGroup == 4) { + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row3 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, + src.step, rewriter); + srcIndices[mDimIdx] = op_addi(srcIndices[mDimIdx], one); + Value row4 = loadWithPrefetch(loc, srcVecTy, src.memRef, srcIndices, + src.step, rewriter); + row1 = op_interleave(row1, row3); + row2 = op_interleave(row2, row4); + } + Value shuffled = op_interleave(row1, row2); + SmallVector dstIndices = dst.indices; + dstIndices[dstIndices.size() - 2] = + op_addi(dstIndices[dstIndices.size() - 2], ivVal); + op_write(shuffled, dst.memRef, dstIndices); + rewriter.setInsertionPointAfter(forOp); +} + // Prepare temporary buffers to be used for tile loads. If the original // value can be directly loaded to tiles from its original memory, then // use it instead. Return empty buffer if source value is all zeros and @@ -446,18 +395,25 @@ void interleaveAndStore(Location loc, Value val, Value buf, // // If interleave flag is set, then pre-pack RHS before store. See // interleaveAndStore for more details. 
-AmxBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, +MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, bool skipForZeros, bool readOnly, Operation *allocaPoint, PatternRewriter &rewriter) { LDBG("Preparing buffer (interleave=" << interleave << ") for a vector: " << val); - auto valLoad = val.getDefiningOp(); - if (valLoad && !interleave && readOnly && !hasMaskOrBoundsCheck(valLoad)) { - Value memRef = valLoad.getSource(); - ValueRange indices = valLoad.getIndices(); - LDBG(" Reusing the original memref for a buffer: " << memRef); - return {memRef, indices}; + auto vecTy = cast(val.getType()); + MemBuffer inputBuf = findInputBuffer(val); + if (!inputBuf.empty()) { + if (interleave) { + LDBG(" Copying from the original memref with interleave: " + << inputBuf.memRef); + auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy), + allocaPoint, rewriter); + copyWithInterleave(loc, vecTy, inputBuf, tmpBuf, rewriter); + return tmpBuf; + } + LDBG(" Reusing the original memref for a buffer: " << inputBuf.memRef); + return inputBuf; } if (skipForZeros && isZeroConst(val)) { @@ -465,15 +421,14 @@ AmxBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, return {}; } - auto vecTy = cast(val.getType()); if (interleave) vecTy = getSwizzledRhsTileType(vecTy); - AmxBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); if (interleave) { interleaveAndStore(loc, val, buf.memRef, rewriter); } else { - rewriter.create(loc, val, buf.memRef, buf.indices); + op_write(val, buf.memRef, buf.indices); } return buf; @@ -482,8 +437,8 @@ AmxBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, // Return a buffer where the final result should be stored. If result can // be directly stored to the output memory, then it is used as an output // buffer. Otherwise, re-use accumulator buffer or create a new one. -AmxBuffer prepareResultBuffer(Location loc, Value val, const AmxBuffer &accBuf, - const AmxBuffer &outBuf, Operation *allocaPoint, +MemBuffer prepareResultBuffer(Location loc, Value val, const MemBuffer &accBuf, + const MemBuffer &outBuf, Operation *allocaPoint, PatternRewriter &rewriter) { if (!outBuf.empty()) { LDBG("Output memory will be used for direct tile stores."); @@ -500,37 +455,22 @@ AmxBuffer prepareResultBuffer(Location loc, Value val, const AmxBuffer &accBuf, rewriter); } -Value shiftIndex(Location loc, Value index, int64_t offs, - PatternRewriter &rewriter) { - if (!offs) - return index; - - // Do constant folding right away here for better code readability - // after the pass. 
- auto cstOp = dyn_cast(index.getDefiningOp()); - if (cstOp) { - int64_t oldVal = cast(cstOp.getValue()).getInt(); - return rewriter.create(loc, oldVal + offs); - } - - Value offsVal = rewriter.create(loc, offs); - return rewriter.create(loc, index.getType(), index, offsVal); -} - -SmallVector shiftIndices(Location loc, ArrayRef indices, - amx::TileType tileTy, int64_t tilesInBlockM, - int64_t tilesInBlockN, int64_t blockM, - int64_t blockN, int64_t tileM, int64_t tileN, - PatternRewriter &rewriter) { +SmallVector shiftIndices(Location loc, ArrayRef indices, + amx::TileType tileTy, int64_t tilesInBlockM, + int64_t tilesInBlockN, int64_t blockM, + int64_t blockN, int64_t tileM, int64_t tileN, + PatternRewriter &rewriter) { int64_t blockOffsM = blockM * tilesInBlockM * tileTy.getDimSize(0); int64_t blockOffsN = blockN * tilesInBlockN * tileTy.getDimSize(1); int64_t tileOffsM = blockOffsM + tileM * tileTy.getDimSize(0); int64_t tileOffsN = blockOffsN + tileN * tileTy.getDimSize(1); - return {shiftIndex(loc, indices[0], tileOffsM, rewriter), - shiftIndex(loc, indices[1], tileOffsN, rewriter)}; + SmallVector res(indices.begin(), indices.end() - 2); + res.push_back(shiftIndex(loc, *(indices.end() - 2), tileOffsM, rewriter)); + res.push_back(shiftIndex(loc, *(indices.end() - 1), tileOffsN, rewriter)); + return res; } -Value loadTile(Location loc, amx::TileType tileTy, const AmxBuffer &buf, +Value loadTile(Location loc, amx::TileType tileTy, const MemBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, int64_t tileM, int64_t tileN, PatternRewriter &rewriter) { @@ -541,7 +481,7 @@ Value loadTile(Location loc, amx::TileType tileTy, const AmxBuffer &buf, } void storeTile(Location loc, amx::TileType tileTy, Value val, - const AmxBuffer &buf, int64_t tilesInBlockM, + const MemBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, int64_t tileM, int64_t tileN, PatternRewriter &rewriter) { auto indices = @@ -551,7 +491,7 @@ void storeTile(Location loc, amx::TileType tileTy, Value val, } SmallVector> -loadBlockTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, +loadBlockTiles(Location loc, amx::TileType tileTy, const MemBuffer &buf, int64_t tilesInBlockM, int64_t tilesInBlockN, int64_t blockM, int64_t blockN, PatternRewriter &rewriter) { SmallVector> res(tilesInBlockM); @@ -567,22 +507,18 @@ loadBlockTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, return res; } -// Move acc to a tile for the whole loop. It might be loads from memory or -// zero tiles. -SmallVector> -moveLoopAccToTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, - int64_t tilesInBlockM, int64_t tilesInBlockN, - PatternRewriter &rewriter) { - LDBG("Loading accumulator to tiles before the loop."); - auto res = loadBlockTiles(loc, tileTy, buf, tilesInBlockM, tilesInBlockN, 0, - 0, rewriter); - - // TODO: add new block args into ForOp and return them instead. - // Yield directly uses them for now and will be patched after mul - // ops generation. 
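loadTile/storeTile and loadBlockTiles above address a tile within a block as offset = (block * tilesInBlock + tile) * tileSize along each of the two trailing dimensions. A NumPy sketch of the same addressing (illustrative names, not part of the patch):

    import numpy as np

    # offset = (block * tiles_in_block + tile) * tile_size along each dimension.
    def tile_offsets(block_m, block_n, tile_m, tile_n,
                     tiles_in_block_m, tiles_in_block_n, tile_shape):
        offs_m = (block_m * tiles_in_block_m + tile_m) * tile_shape[0]
        offs_n = (block_n * tiles_in_block_n + tile_n) * tile_shape[1]
        return offs_m, offs_n

    def load_block_tiles(buf, block_m, block_n, tiles_m, tiles_n, tile_shape):
        tiles = [[None] * tiles_n for _ in range(tiles_m)]
        for m in range(tiles_m):
            for n in range(tiles_n):
                om, on = tile_offsets(block_m, block_n, m, n,
                                      tiles_m, tiles_n, tile_shape)
                tiles[m][n] = buf[om:om + tile_shape[0], on:on + tile_shape[1]]
        return tiles

    acc = np.zeros((64, 64), dtype=np.float32)
    tiles = load_block_tiles(acc, 0, 0, 2, 2, (16, 16))
    assert tiles[1][1].shape == (16, 16)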
- llvm_unreachable("Not yet supported."); - - return res; +void storeBlockTiles(Location loc, amx::TileType tileTy, const MemBuffer &buf, + int64_t blockM, int64_t blockN, + const SmallVector> &tiles, + PatternRewriter &rewriter) { + int64_t tilesInBlockM = tiles.size(); + int64_t tilesInBlockN = tiles[0].size(); + for (int64_t m = 0; m < tilesInBlockM; ++m) { + for (int64_t n = 0; n < tilesInBlockN; ++n) { + storeTile(loc, tileTy, tiles[m][n], buf, tilesInBlockM, tilesInBlockN, + blockM, blockN, m, n, rewriter); + } + } } // Multiply two blocks. LHS block is preloaded to tiles with the following @@ -590,8 +526,8 @@ moveLoopAccToTiles(Location loc, amx::TileType tileTy, const AmxBuffer &buf, // Optionally, results can also be stored to accBuf. void multiplyBlocksPreloadLhs(Location loc, amx::TileType lhsTileTy, amx::TileType rhsTileTy, amx::TileType accTileTy, - const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, - const AmxBuffer &accBuf, int64_t blockM, + const MemBuffer &lhsBuf, const MemBuffer &rhsBuf, + const MemBuffer &accBuf, int64_t blockM, int64_t blockN, int64_t blockK, int64_t tilesInBlockM, int64_t tilesInBlockN, SmallVector> &accTiles, @@ -626,8 +562,8 @@ void multiplyBlocksPreloadLhs(Location loc, amx::TileType lhsTileTy, // Similar to multiplyBlocksPreloadLhs but here RHS is preloaded to tiles. void multiplyBlocksPreloadRhs(Location loc, amx::TileType lhsTileTy, amx::TileType rhsTileTy, amx::TileType accTileTy, - const AmxBuffer &lhsBuf, const AmxBuffer &rhsBuf, - const AmxBuffer &accBuf, int64_t blockM, + const MemBuffer &lhsBuf, const MemBuffer &rhsBuf, + const MemBuffer &accBuf, int64_t blockM, int64_t blockN, int64_t blockK, int64_t tilesInBlockM, int64_t tilesInBlockN, SmallVector> &accTiles, @@ -691,11 +627,11 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, // Cast input data if required and prepare input buffer. It might be temporary // buffers with stored vectors or the original input memory. Value lhs = maybeCast(loc, op.getA(), candidate.lhsTileElemTy, rewriter); - AmxBuffer lhsBuf = + MemBuffer lhsBuf = prepareTensorBuffer(loc, lhs, false, false, true, allocaPoint, rewriter); Value rhs = maybeCast(loc, op.getB(), candidate.rhsTileElemTy, rewriter); - AmxBuffer rhsBuf = + MemBuffer rhsBuf = prepareTensorBuffer(loc, rhs, true, false, true, allocaPoint, rewriter); Value acc = maybeCast(loc, op.getC(), candidate.accTileElemTy, rewriter); @@ -705,7 +641,7 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, forOp = cast(op->getParentOp()); accToStore = getInitAccValue(acc); } - AmxBuffer accBuf; + MemBuffer accBuf; { // If accumulator is bufferized then we should move initial values before // the loop. @@ -717,14 +653,24 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, false, allocaPoint, rewriter); } - AmxBuffer resBuf = prepareResultBuffer( + MemBuffer resBuf = prepareResultBuffer( loc, op.getResult(), accBuf, candidate.outBuf, allocaPoint, rewriter); SmallVector> accTiles; - if (candidate.keepAccOnTiles) - accTiles = - moveLoopAccToTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, - candidate.tilesInBlockN, rewriter); + SmallVector> accInitTiles; + if (candidate.keepAccOnTiles) { + // Initial tile values are loaded before the loop and then directly + // used within the loop. Later, new iter values will be added to + // add loop carried-dependencies for accumulator tiles and accInitTiles + // will be used as initializers for them. 
+ OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(forOp); + LDBG("Loading accumulator to tiles before the loop."); + accInitTiles = + loadBlockTiles(loc, accTileTy, accBuf, candidate.tilesInBlockM, + candidate.tilesInBlockN, 0, 0, rewriter); + accTiles = accInitTiles; + } int64_t blocksInAccM = accTy.getDimSize(0) / candidate.tileM / candidate.tilesInBlockM; @@ -743,6 +689,7 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, // TODO: enable forward store for acc kept in tiles. bool storeAcc = !candidate.keepAccOnTiles && (blocK == (tilesInVectorK - 1)); + // We need to choose which block (LHS or RHS) to keep on tiles. // E.g. for ACC block 4x1 tiles, LHS block is also 4 tiles, so // we would use all tile registers trying to keep both ACC and @@ -762,37 +709,98 @@ LogicalResult convertCandidate(AmxDotOpCandidate &candidate, } } - // TODO: For keepAccOnTiles fix YieldOp to use mul results. - // TODO: For keepAccOnTiles move all new forOp results to vector through a - // buffer. - if (candidate.keepAccOnTiles) - llvm_unreachable("Not yet supported."); + if (candidate.keepAccOnTiles) { + // In this case we have the whole accumulator/result on tiles. Loop + // carried dependencies are not in place yet and should be added. + // After the loop, resulting tiles should either be stored to the + // output buffer, or moved to a vector though a temporary buffer. + + // We don't need the original accumulator and contraction op anymore. + // Directly yield orig accumulator value, so it would be later removed + // as unused. The original contraction can be removed right away. + int64_t origResIdx = op.getResult().getUses().begin()->getOperandNumber(); + rewriter.replaceOp(op, op.getC()); + + // Now, replace the loop with a new one to add loop carried dependency for + // accumulator tiles. + LDBG("Rewrite loop to introduce loop carried dependencies for accumulator " + "tiles."); + SmallVector newInitOperands; + SmallVector newYieldedValues; + for (int64_t m = 0; m < candidate.tilesInBlockM; ++m) + for (int64_t n = 0; n < candidate.tilesInBlockN; ++n) { + LDBG("Initial value\n " << accInitTiles[m][n] + << "\nis combined with\n " << accTiles[m][n]); + newInitOperands.push_back(accInitTiles[m][n]); + newYieldedValues.push_back(accTiles[m][n]); + } + auto newForOp = cast(*forOp.replaceWithAdditionalYields( + rewriter, newInitOperands, true, + [&newYieldedValues](OpBuilder &b, Location loc, + ArrayRef newBBArgs) { + return newYieldedValues; + })); + + // The resulting tiles are now in the new loop results. + auto resTiles = newForOp.getResults().take_back(newYieldedValues.size()); + for (int64_t m = 0; m < candidate.tilesInBlockM; ++m) + for (int64_t n = 0; n < candidate.tilesInBlockN; ++n) { + accTiles[m][n] = resTiles[m * candidate.tilesInBlockN + n]; + } - if (candidate.keepAccInBuf) { - int resIdx = op.getResult().getUses().begin()->getOperandNumber(); - Value loopRes = forOp.getResult(resIdx); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(newForOp); + if (candidate.outBuf.empty()) { + // Move tiles to a vector through a temporary buffer and use it instead + // of the original one. + LDBG("Moving resulting tiles to a vector through memory."); + VectorType resTy = accTy.cloneWith(std::nullopt, candidate.accTileElemTy); + storeBlockTiles(loc, accTileTy, resBuf, 0, 0, accTiles, rewriter); + Value newVal = op_read(resTy, resBuf.memRef, resBuf.indices); + // We might need to cast back to the original type. 
+ newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); + rewriter.replaceAllUsesWith(newForOp.getResult(origResIdx), newVal); + } else { + // Store tiles directly to the output buffer and remove the original + // store. + LDBG("Storing resulting tiles to the output memory."); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(candidate.origStore); + storeBlockTiles(loc, accTileTy, candidate.outBuf, 0, 0, accTiles, + rewriter); + rewriter.eraseOp(candidate.origStore); + } + } else if (candidate.keepAccInBuf) { + // The result is in the buffer. We should load it and replace one of the + // loop results. The original contraction op can be removed. + // TODO: should we try to store to the output buffer on the last iteration? + Value loopRes = forOp.getTiedLoopResult(cast(op.getC())); LDBG( "Loading buffererized accumulator to a vector to replace loop result."); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(forOp); - Value newVal = rewriter.create( - loc, cast(acc.getType()), resBuf.memRef, resBuf.indices); + Value newVal = + op_read(cast(acc.getType()), resBuf.memRef, resBuf.indices); // We might need to cast back to the original type. newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); rewriter.replaceAllUsesWith(loopRes, newVal); - // For now, just use init value for unused ForOp result instead of - // its removal. + // Directly yield orig accumulator iter value. It will be removed as unused + // later. rewriter.replaceOp(op, op.getC()); } else if (candidate.outBuf.empty()) { + // The result is in the buffer. We should load it and replace the original + // constraction result. LDBG("Loading the result to a vector to replace orig op result."); - Value newVal = rewriter.create( - loc, cast(acc.getType()), resBuf.memRef, resBuf.indices); + Value newVal = + op_read(cast(acc.getType()), resBuf.memRef, resBuf.indices); // We might need to cast back to the original type. newVal = maybeCast(loc, newVal, accTy.getElementType(), rewriter); rewriter.replaceOp(op, newVal); } else { + // The result is already in the output buffer. We just need to remove the + // original contraction and store operation. LDBG("Removing original operation and its use."); - rewriter.eraseOp(*op.getResult().user_begin()); + rewriter.eraseOp(candidate.origStore); rewriter.eraseOp(op); } From 220b95ab5cdaf9b5fc01bd30b786f7237cbeaab5 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Tue, 17 Dec 2024 12:23:18 -0600 Subject: [PATCH 155/165] Fix extra-store in matmul tutorial. (#198) Signed-off-by: Ilya Enkovich --- python/tutorials/03-matrix-multiplication-cpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tutorials/03-matrix-multiplication-cpu.py b/python/tutorials/03-matrix-multiplication-cpu.py index 3b44a30bf7ad..f14bed73a66e 100644 --- a/python/tutorials/03-matrix-multiplication-cpu.py +++ b/python/tutorials/03-matrix-multiplication-cpu.py @@ -273,7 +273,7 @@ def matmul_kernel( offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_tile_ptr = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(c_tile_ptr, c) + tl.store(c_tile_ptr, c) # %% From 485d7092b38dde6df3f5748b05f123d3421657f2 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Wed, 18 Dec 2024 21:29:43 -0600 Subject: [PATCH 156/165] Remove unnecessary bounds checks. 
(#199) Signed-off-by: Ilya Enkovich --- third_party/cpu/include/TritonCPUTransforms/OptCommon.h | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h index 09e7ec65595d..c3fe3973ce0b 100644 --- a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h +++ b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h @@ -149,9 +149,12 @@ inline Value shapeCast(Location loc, Value in, #define op_sitofp(ty, val) rewriter.create(loc, ty, val) #define op_fptosi(ty, val) rewriter.create(loc, ty, val) #define op_read(ty, memRef, indices) \ - rewriter.create(loc, ty, memRef, indices) + rewriter.create( \ + loc, ty, memRef, indices, SmallVector(ty.getRank(), true)) #define op_write(val, memRef, indices) \ - rewriter.create(loc, val, memRef, indices) + rewriter.create( \ + loc, val, memRef, indices, \ + SmallVector(cast(val.getType()).getRank(), true)) #define op_interleave(lhs, rhs) \ rewriter.create(loc, lhs, rhs) #define op_extract(vec, idx) rewriter.create(loc, vec, idx) From 561c962a4009ad65dcadeec0bc4aeece1744261e Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Sat, 21 Dec 2024 02:30:18 -0600 Subject: [PATCH 157/165] Enable armv8 CI (#195) * [Setup] Skip hatchet pip package for now This does not exist for Darwin + Arm64. TODO: Enable this selectively when possible. * [CPU][driver] Skip non-existent sys paths * [mac-arm64] Add GH CI support - look into faster triton install - enable bf16 tests - enable openmp --- .github/workflows/build-test.yml | 63 +++++++++++++++++++++++++++---- python/setup.py | 2 +- third_party/cpu/backend/driver.py | 14 ++++++- 3 files changed, 68 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index a5178e8f34c8..805c6b8dc7b0 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -45,14 +45,14 @@ jobs: python3 -m pre_commit run --show-diff-on-failure --color=always --all-files --verbose build-test: - name: Build and test - runs-on: - - glados - - intel - - x86 + name: Build and test on ${{ matrix.config.runner }} + runs-on: ${{ matrix.config.runs_on }} strategy: matrix: python: ['3.11'] + config: + - {runner: 'Ubuntu Intel x86', runs_on: ['glados', 'intel', 'x86'], target-os: 'ubuntu', arch: 'x86'} + - {runner: 'MacOS-latest ARM64', runs_on: ['macos-latest'], target-os: 'macos', arch: 'arm64'} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -65,11 +65,16 @@ jobs: python-version: ${{ matrix.python }} - name: Install pip and apt dependencies + env: + RUNNER_TARGET_OS: ${{ matrix.config.target-os }} run: | + echo "RUNNER_TARGET_OS: ${RUNNER_TARGET_OS}" python3 -m pip install --upgrade pip python3 -m pip install wheel cmake==3.24 ninja pytest-xdist lit pybind11 - sudo apt-get update - sudo apt-get install -y zlib1g-dev g++ + if [[ "${RUNNER_TARGET_OS}" == "ubuntu" ]]; then + sudo apt-get update + sudo apt-get install -y zlib1g-dev g++ + fi pip install torch==2.1.2 - name: Install Triton @@ -78,7 +83,49 @@ jobs: cd python python3 -m pip install --no-build-isolation -vvv '.[tests]' - - name: Run python unit tests + - name: Run python unit tests for MacOS Arm64 + if: matrix.config.target-os == 'macos' + run: | + export CC=$(which clang) + export TRITON_DISABLE_OPENMP=1 # temporary + export TRITON_CPU_BACKEND=1 + + # Document some versions/flags + echo "xcode-select:"; xcode-select -p + echo "CC: ${CC}" + clang --version + echo 
"TRITON_DISABLE_OPENMP=${TRITON_DISABLE_OPENMP}" + echo "TRITON_CPU_BACKEND=${TRITON_CPU_BACKEND}" + + # Skip bfloat16 tests for now + # We are generating bfcvt for bfloat16 tests when converting to fp32. + # This is only for Clang15, works OK for Clang16 + # TODO - fix this using driver flags. + python -m pytest -s -n 32 --device cpu \ + python/test/unit/language/test_core.py -m cpu -k "not bfloat16" + python -m pytest -s -n 32 --device cpu \ + python/test/unit/cpu/test_math.py \ + python/test/unit/cpu/test_opt.py \ + python/test/unit/language/test_annotations.py \ + python/test/unit/language/test_block_pointer.py \ + python/test/unit/language/test_compile_errors.py \ + python/test/unit/language/test_conversions.py \ + python/test/unit/language/test_decorator.py \ + python/test/unit/language/test_pipeliner.py \ + python/test/unit/language/test_random.py \ + python/test/unit/language/test_standard.py \ + python/test/unit/runtime/test_autotuner.py \ + python/test/unit/runtime/test_bindings.py \ + python/test/unit/runtime/test_cache.py \ + python/test/unit/runtime/test_driver.py \ + python/test/unit/runtime/test_jit.py \ + python/test/unit/runtime/test_launch.py \ + python/test/unit/runtime/test_subproc.py \ + python/test/unit/test_debug_dump.py \ + -k "not bfloat16" + + - name: Run python unit tests for Intel + if: matrix.config.target-os == 'ubuntu' run: | python -m pytest -s -n 32 --device cpu python/test/unit/language/test_core.py -m cpu python -m pytest -s -n 32 --device cpu \ diff --git a/python/setup.py b/python/setup.py index 6fcea0acc354..37225c944b0c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -789,7 +789,7 @@ def get_git_version_suffix(): "pytest-forked", "pytest-xdist", "scipy>=1.7.1", - "llnl-hatchet", + # "llnl-hatchet", # TODO: Re-enable this, not available on macos-arm64 ], "tutorials": [ "matplotlib", diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index cadb76e1229a..3308fd23c680 100644 --- a/third_party/cpu/backend/driver.py +++ b/third_party/cpu/backend/driver.py @@ -12,6 +12,7 @@ from triton.backends.driver import DriverBase from triton.backends.compiler import GPUTarget +from pathlib import Path from triton._C.libtriton import llvm _dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local") @@ -22,10 +23,19 @@ # resources.files() doesn't exist for Python < 3.9 _triton_C_dir = importlib.resources.path(triton, "_C").__enter__() -include_dirs = [os.path.join(_dirname, "include")] -library_dirs = [os.path.join(_dirname, "lib"), _triton_C_dir] +include_dirs = [] +library_dirs = [_triton_C_dir] libraries = ["stdc++"] +# Skip non-existent paths +sys_include_dir = os.path.join(_dirname, "include") +if os.path.exists(sys_include_dir): + include_dirs.append(sys_include_dir) + +sys_lib_dir = os.path.join(_dirname, "lib") +if os.path.exists(sys_lib_dir): + library_dirs.append(sys_lib_dir) + def compile_module_from_src(src, name): key = hashlib.md5(src.encode("utf-8")).hexdigest() From 5b430eedb191760d9c3e1d030078d743740c66d5 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 23 Dec 2024 13:05:22 -0600 Subject: [PATCH 158/165] Fix isSigned usage for scalar prints. 
(#201) Signed-off-by: Ilya Enkovich --- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index 21b4756b506e..33e1753e31b2 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -32,8 +32,10 @@ class TritonLLVMConversionTarget : public ConversionTarget { // TODO: This code is the same as the GPU-backend code. Consider refactoring. std::string getFormatSubstr(Value value, bool hex = false, - std::optional width = std::nullopt) { + std::optional width = std::nullopt, + bool isSigned = false) { Type type = value.getType(); + // If the `value` is a pointer, just return %p. if (isa(type)) { return "%p"; } @@ -52,23 +54,15 @@ std::string getFormatSubstr(Value value, bool hex = false, std::string prefix = "%"; if (width.has_value()) { prefix += std::to_string(*width); - } else if (hex) { - prefix += "0"; - prefix += std::to_string(value.getType().getIntOrFloatBitWidth() / 4); } if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) { return prefix + "f"; - } else if (type.isSignedInteger()) { + } else if (type.isInteger()) { if (type.getIntOrFloatBitWidth() == 64) - return prefix + "lli"; + return prefix + (isSigned ? "lli" : "llu"); else - return prefix + "i"; - } else if (type.isUnsignedInteger() || type.isSignlessInteger()) { - if (type.getIntOrFloatBitWidth() == 64) - return prefix + "llu"; - else - return prefix + "u"; + return prefix + (isSigned ? "i" : "u"); } assert(false && "not supported type"); return ""; @@ -163,7 +157,8 @@ static StringRef makeNullTerminatedString(StringRef s) { void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, std::array pid, StringRef prefix, - std::optional arg, bool hex = false) { + std::optional arg, bool hex = false, + bool isSigned = false) { assert(!prefix.empty() && "printf with empty string not supported"); auto loc = UnknownLoc::get(rewriter.getContext()); @@ -172,7 +167,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, os << "(" << getFormatSubstr(pid[0]) << ", " << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")" << prefix; if (arg.has_value()) - os << getFormatSubstr(arg.value(), hex); + os << getFormatSubstr(arg.value(), hex, std::nullopt, isSigned); llvm::SmallString<64> formatStrNewline(formatStr); formatStrNewline.push_back('\n'); @@ -242,7 +237,8 @@ struct PrintOpConversion : public ConvertOpToLLVMPattern { std::nullopt); } else { createRuntimePrintScalarCall(rewriter, pid, op.getPrefix(), - adaptor.getOperands()[0], op.getHex()); + adaptor.getOperands()[0], op.getHex(), + op.getIsSigned()[0]); } rewriter.eraseOp(op); return success(); From 5846858b622a0f69ce25bcd760aaa5ffd200bc55 Mon Sep 17 00:00:00 2001 From: Ilya Enkovich Date: Mon, 13 Jan 2025 16:23:47 -0600 Subject: [PATCH 159/165] Support VNNI pre-encoded input in AMX lowering. 
(#210) Signed-off-by: Ilya Enkovich --- ...d-matmul-fp32.py => cpu-blocked-matmul.py} | 137 +++++++---- test/TritonCPU/dot-to-amx.mlir | 223 ++++++++++++++++++ third_party/cpu/language/cpu/__init__.py | 3 + third_party/cpu/language/cpu/utils.py | 22 ++ .../ConvertDotOp/ConvertDotCommon.cpp | 73 +++++- .../ConvertDotOp/ConvertDotCommon.h | 20 +- .../ConvertDotOp/ConvertDotToAMX.cpp | 11 +- .../TritonToTritonCPU/ConvertElemManipOps.cpp | 21 +- 8 files changed, 435 insertions(+), 75 deletions(-) rename python/tutorials/{cpu-blocked-matmul-fp32.py => cpu-blocked-matmul.py} (73%) create mode 100644 third_party/cpu/language/cpu/utils.py diff --git a/python/tutorials/cpu-blocked-matmul-fp32.py b/python/tutorials/cpu-blocked-matmul.py similarity index 73% rename from python/tutorials/cpu-blocked-matmul-fp32.py rename to python/tutorials/cpu-blocked-matmul.py index 8f0f0ebce41a..e8f274d6c552 100644 --- a/python/tutorials/cpu-blocked-matmul-fp32.py +++ b/python/tutorials/cpu-blocked-matmul.py @@ -18,11 +18,13 @@ import os DTYPE = os.getenv("DTYPE", "float32") +in_dtype = getattr(torch, DTYPE) +out_dtype = torch.float32 if in_dtype.is_floating_point else torch.int32 # Choose block size depending on dtype. We have more register # capacity for bfloat16/float16 compared to float32. BLOCK_SIZE_M = 8 if DTYPE == "float32" else 32 BLOCK_SIZE_N = 32 -BLOCK_SIZE_K = 8 if DTYPE == "float32" else 32 +BLOCK_SIZE_K = 8 if DTYPE == "float32" else 64 // in_dtype.itemsize GROUP_SIZE_M = 8 @@ -38,6 +40,9 @@ # tensor are transposed. It provides contiguos placement for a column # of blocks. # +# If PACKED_B is set to True then B is VNNI encoded. Only works when +# BLOCKED_B is True. +# # If TRANSPOSED_BLOCK_A is set to True then tail dimensions of the LHS # tensor are transposed. 
Transposed LHS block better matches FMA lowering # used by Triton CPU backend which processes RHS block row-by-row and LHS @@ -46,7 +51,7 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, BLOCKED_B: tl.constexpr, - TRANSPOSED_B: tl.constexpr): + TRANSPOSED_B: tl.constexpr, PACKED_B: tl.constexpr): tl.static_assert(BLOCKED_A or not TRANSPOSED_BLOCK_A) tl.static_assert(BLOCKED_B or not TRANSPOSED_B) pid = tl.program_id(axis=0) @@ -85,9 +90,11 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZ tl.store(a_out_ptr, val) if BLOCKED_B: + B_PACKED_NUM: tl.constexpr = 32 // in_b.type.element_ty.primitive_bitwidth if PACKED_B else 1 + PACKED_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_K // B_PACKED_NUM if PACKED_B else BLOCK_SIZE_K + PACKED_BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_N * B_PACKED_NUM if PACKED_B else BLOCK_SIZE_N B_OUT_BLOCKS_K = N // BLOCK_SIZE_N if TRANSPOSED_B else K // BLOCK_SIZE_K B_OUT_BLOCKS_N = K // BLOCK_SIZE_K if TRANSPOSED_B else N // BLOCK_SIZE_N - B_OUT_STRIDE_K: tl.constexpr = BLOCK_SIZE_N B_OUT_STRIDE_BLOCK_K = (K * BLOCK_SIZE_N if TRANSPOSED_B else BLOCK_SIZE_K * N) B_OUT_STRIDE_BLOCK_N: tl.constexpr = BLOCK_SIZE_K * BLOCK_SIZE_N for in_block_k in tl.range(in_block_m, K // BLOCK_SIZE_K, M // BLOCK_SIZE_M): @@ -95,15 +102,28 @@ def block_transpose_combined_kernel(in_a, out_a, in_b, out_b, M, N, K, BLOCK_SIZ b_out_block_n = in_block_k if TRANSPOSED_B else in_block_n b_in_ptr = tl.make_block_ptr(base=in_b, shape=(K, N), strides=(N, 1), offsets=(in_block_k * BLOCK_SIZE_K, in_block_n * BLOCK_SIZE_N), - block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(1, 0)) - b_out_ptr = tl.make_block_ptr(base=out_b, - shape=(B_OUT_BLOCKS_K, B_OUT_BLOCKS_N, BLOCK_SIZE_K, BLOCK_SIZE_N), - strides=(B_OUT_STRIDE_BLOCK_K, B_OUT_STRIDE_BLOCK_N, B_OUT_STRIDE_K, 1), - offsets=(b_out_block_k, b_out_block_n, 0, 0), - block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), order=(3, 2, 1, 0)) - val = tl.load(b_in_ptr) - val = tl.reshape(val, (1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N)) - tl.store(b_out_ptr, val) + block_shape=(1, BLOCK_SIZE_N), order=(1, 0)) + b_out_ptr = tl.make_block_ptr( + base=out_b, shape=(B_OUT_BLOCKS_K, B_OUT_BLOCKS_N, PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N), + strides=(B_OUT_STRIDE_BLOCK_K, B_OUT_STRIDE_BLOCK_N, PACKED_BLOCK_SIZE_N, 1), + offsets=(b_out_block_k, b_out_block_n, 0, 0), block_shape=(1, 1, 1, PACKED_BLOCK_SIZE_N), + order=(3, 2, 1, 0)) + for i in tl.range(0, BLOCK_SIZE_K // B_PACKED_NUM): + row1 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + if B_PACKED_NUM > 1: + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + row2 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + if B_PACKED_NUM > 2: + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + row3 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + row4 = tl.load(b_in_ptr).reshape((BLOCK_SIZE_N, )) + row1 = tl.ravel(tl.join(row1, row3)) + row2 = tl.ravel(tl.join(row2, row4)) + row1 = tl.ravel(tl.join(row1, row2)) + tl.store(b_out_ptr, row1.reshape((1, 1, 1, PACKED_BLOCK_SIZE_N))) + b_in_ptr = tl.advance(b_in_ptr, (1, 0)) + b_out_ptr = tl.advance(b_out_ptr, (0, 0, 1, 0)) # Matmul kernel that computes a single output block [BLOCK_SIZE_M, BLOCK_SIZE_N]. 
LHS can be in the @@ -125,7 +145,7 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOC BLOCK_SIZE_K: tl.constexpr, # number of blocks in a group GROUP_SIZE_M: tl.constexpr, BLOCKED_A: tl.constexpr, TRANSPOSED_BLOCK_A: tl.constexpr, - BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr): + BLOCKED_B: tl.constexpr, TRANSPOSED_B: tl.constexpr, PACKED_B: tl.constexpr, OUT_DTYPE: tl.constexpr): # TRANSPOSED_BLOCK_A means that each block in A is transposed. # It is allowed only for blocked input. assert (BLOCKED_A or not TRANSPOSED_BLOCK_A) @@ -151,37 +171,43 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOC a_stride_block_k = A_BLOCK_SIZE_M * A_BLOCK_SIZE_K if BLOCKED_A else A_BLOCK_SIZE_K a_stride_block_m = BLOCK_SIZE_M * K + B_PACKED_NUM: tl.constexpr = 32 // b_ptr.type.element_ty.primitive_bitwidth if PACKED_B else 1 + PACKED_BLOCK_SIZE_K: tl.constexpr = BLOCK_SIZE_K // B_PACKED_NUM if PACKED_B else BLOCK_SIZE_K + PACKED_BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_N * B_PACKED_NUM if PACKED_B else BLOCK_SIZE_N + assert BLOCKED_B or not TRANSPOSED_B b_stride_n = 1 - b_stride_k = BLOCK_SIZE_N if BLOCKED_B else N + b_stride_k = PACKED_BLOCK_SIZE_N if BLOCKED_B else N * B_PACKED_NUM if TRANSPOSED_B: b_stride_block_n = BLOCK_SIZE_N * K b_stride_block_k = BLOCK_SIZE_K * BLOCK_SIZE_N else: - b_stride_block_n = BLOCK_SIZE_K * BLOCK_SIZE_N if BLOCKED_B else BLOCK_SIZE_N + b_stride_block_n = BLOCK_SIZE_K * BLOCK_SIZE_N if BLOCKED_B else PACKED_BLOCK_SIZE_N b_stride_block_k = BLOCK_SIZE_K * N a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(A_BLOCKS_M, A_BLOCKS_K, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), strides=(a_stride_block_m, a_stride_block_k, a_stride_m, a_stride_k), offsets=(block_m, 0, 0, 0), block_shape=(1, 1, A_BLOCK_SIZE_M, A_BLOCK_SIZE_K), order=(3, 2, 1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, - shape=(K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, BLOCK_SIZE_K, BLOCK_SIZE_N), - strides=(b_stride_block_k, b_stride_block_n, b_stride_k, b_stride_n), - offsets=(0, block_n, 0, 0), block_shape=(1, 1, BLOCK_SIZE_K, BLOCK_SIZE_N), - order=(3, 2, 1, 0)) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, shape=(K // BLOCK_SIZE_K, N // BLOCK_SIZE_N, PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N), + strides=(b_stride_block_k, b_stride_block_n, b_stride_k, b_stride_n), offsets=(0, block_n, 0, 0), + block_shape=(1, 1, PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N), order=(3, 2, 1, 0)) c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(N, 1), offsets=(block_m * BLOCK_SIZE_M, block_n * BLOCK_SIZE_N), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) - c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=OUT_DTYPE) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_block_ptr).reshape((A_BLOCK_SIZE_M, A_BLOCK_SIZE_K)) - b = tl.load(b_block_ptr).reshape((BLOCK_SIZE_K, BLOCK_SIZE_N)) + b = tl.load(b_block_ptr).reshape((PACKED_BLOCK_SIZE_K, PACKED_BLOCK_SIZE_N)) if TRANSPOSED_BLOCK_A: a = a.T - c += tl.dot(a, b, out_dtype=tl.float32) + if PACKED_B: + b = tl.extra.cpu.vnni_decode(b) + + c += tl.dot(a, b, out_dtype=OUT_DTYPE) a_block_ptr = tl.advance(a_block_ptr, (0, 1, 0, 0)) b_block_ptr = tl.advance(b_block_ptr, (1, 0, 0, 0)) @@ -190,7 +216,7 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOC def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, PREPACKED, - BLOCKED_A, TRANSPOSED_BLOCK_A, 
BLOCKED_B, TRANSPOSED_B, num_threads=0): + BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, PACKED_B, num_threads=0): #TODO: Currently masked load is not supported yet. assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" @@ -203,7 +229,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # - BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B) + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, PACKED_B=PACKED_B) if BLOCKED_A: a = ab if BLOCKED_B: @@ -214,7 +240,8 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # - BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, num_threads=num_threads) + BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, PACKED_B=PACKED_B, # + OUT_DTYPE=tl.float32 if a.dtype.is_floating_point else tl.int32, num_threads=num_threads) return c @@ -227,13 +254,17 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, triton.runtime.driver.set_active_to_cpu() -a = torch.randn((512, 512), device='cpu', dtype=torch.float32) -b = torch.randn((512, 512), device='cpu', dtype=torch.float32) -c = torch.empty((512, 512), device='cpu', dtype=torch.float32) -torch_output = torch.matmul(a, b) +if in_dtype.is_floating_point: + a = torch.randn((512, 512), device='cpu', dtype=in_dtype) + b = torch.randn((512, 512), device='cpu', dtype=in_dtype) +else: + a = torch.randint(0, 5, (512, 512), device='cpu', dtype=in_dtype) + b = torch.randint(0, 5, (512, 512), device='cpu', dtype=in_dtype) +c = torch.empty((512, 512), device='cpu', dtype=out_dtype) +torch_output = torch.matmul(a.to(out_dtype), b.to(out_dtype)) rtol = 0 -a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=torch.float32) -b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=torch.float32) +a_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_M) * (512 // BLOCK_SIZE_K) * 64), device='cpu', dtype=in_dtype) +b_tmp = torch.zeros((512 * 512 + (512 // BLOCK_SIZE_K) * (512 // BLOCK_SIZE_N) * 64), device='cpu', dtype=in_dtype) triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, True, False, False, False, False, False) if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU and TorchCPU match") @@ -241,7 +272,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, print("❌ TritonCPU and TorchCPU differ, the maximum difference is " f'{torch.max(torch.abs(triton_output - torch_output))}') assert False -triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, True) +triton_output = matmul(a, b, c, a_tmp, b_tmp, 512, 512, 512, False, True, True, True, True, DTYPE != "float32") if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): print("✅ TritonCPU pre-packed and TorchCPU match") else: @@ -260,13 +291,13 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, # but feel free to arrange this script as you wish to benchmark any other matrix shape. 
-def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype): - assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' - return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-{dtype}" +def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype): + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' or dtype == 'int8' + return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-pb' if packed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-{dtype}" def encode_torch_provider(single_thread, dtype): - assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' + assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' or dtype == 'int8' return f"torch-cpu-native{'-st' if single_thread else ''}-{dtype}" @@ -277,28 +308,30 @@ def decode_provider(provider): dtype = torch.float16 elif '-float32' in provider: dtype = torch.float32 + elif '-int8' in provider: + dtype = torch.int8 if 'triton-cpu' in provider: backend = 'triton-cpu' elif 'torch-cpu-native' in provider: backend = 'torch-cpu-native' elif 'torch-cpu-compile' in provider: backend = 'torch-cpu-compile' - return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-prepack' in provider, '-st' in provider, dtype + return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-pb' in provider, '-prepack' in provider, '-st' in provider, dtype BLOCK_TRANSPOSE_A_OPTS = [(False, False)] -BLOCK_TRANSPOSE_B_OPTS = [(True, True), (False, False)] +BLOCK_TRANSPOSE_PACK_B_OPTS = [(True, True, True), (True, True, False), (False, False, False)] PREPACK_OPTS = [False, True] SINGLE_THREAD_OPTS = [False] DTYPE_OPTS = [DTYPE] LINE_VALS = [ - encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype) + encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype) for single_thread in SINGLE_THREAD_OPTS for blocked_a, transposed_a in BLOCK_TRANSPOSE_A_OPTS - for blocked_b, transposed_b in BLOCK_TRANSPOSE_B_OPTS + for blocked_b, transposed_b, packed_b in BLOCK_TRANSPOSE_PACK_B_OPTS for prepack in PREPACK_OPTS for dtype in DTYPE_OPTS - if blocked_a or blocked_b or not prepack + if (blocked_a or blocked_b or not prepack) and (not packed_b or dtype != "float32") ] + [encode_torch_provider(single_thread, dtype) for dtype in DTYPE_OPTS for single_thread in SINGLE_THREAD_OPTS] LINE_NAMES = LINE_VALS LINE_STYLES = None @@ -323,9 +356,14 @@ def decode_provider(provider): def benchmark(M, N, K, provider): device = 'cpu' if 'cpu' in provider else 'cuda' - backend, blocked_a, transposed_a, blocked_b, transposed_b, prepack, single_thread, dtype = decode_provider(provider) - a = torch.randn((M, K), device=device, dtype=dtype) - b = torch.randn((K, N), device=device, dtype=dtype) + backend, blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype = decode_provider( + provider) + if dtype.is_floating_point: + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((K, N), device=device, dtype=dtype) + else: + a = torch.randint(0, 5, (M, K), device=device, 
dtype=dtype) + b = torch.randint(0, 5, (K, N), device=device, dtype=dtype) if single_thread: torch.set_num_threads(1) @@ -333,10 +371,10 @@ def benchmark(M, N, K, provider): torch.set_num_threads(default_num_threads) if backend == 'triton-cpu': - c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + c = torch.zeros((M, N), device=a.device, dtype=out_dtype) a_tmp = torch.zeros((M * K + (M // BLOCK_SIZE_M) * (K // BLOCK_SIZE_K) * 64), device=device, dtype=dtype) b_tmp = torch.zeros((K * N + (K // BLOCK_SIZE_K) * (N // BLOCK_SIZE_N) * 64), device=device, dtype=dtype) - c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + c = torch.zeros((M, N), device=a.device, dtype=out_dtype) if prepack and (blocked_a or blocked_b): grid = ((M // BLOCK_SIZE_M) * (N // BLOCK_SIZE_N), ) block_transpose_combined_kernel[grid]( @@ -345,7 +383,7 @@ def benchmark(M, N, K, provider): BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, # GROUP_SIZE_M=GROUP_SIZE_M, # BLOCKED_A=blocked_a, TRANSPOSED_BLOCK_A=transposed_a, # - BLOCKED_B=blocked_b, TRANSPOSED_B=transposed_b) + BLOCKED_B=blocked_b, TRANSPOSED_B=transposed_b, PACKED_B=packed_b) if blocked_a: a = a_tmp if blocked_b: @@ -362,7 +400,8 @@ def benchmark(M, N, K, provider): elif backend == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench( lambda: matmul(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, transposed_b, - num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, rep=1000) + packed_b, num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, + rep=1000) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/test/TritonCPU/dot-to-amx.mlir b/test/TritonCPU/dot-to-amx.mlir index da501849f723..5b6a020306e6 100644 --- a/test/TritonCPU/dot-to-amx.mlir +++ b/test/TritonCPU/dot-to-amx.mlir @@ -236,3 +236,226 @@ module { tt.return loc(#loc) } loc(#loc) } loc(#loc) + +// ----- + +// A case with VNNI pre-encoded RHS that can be directly accessed from the input memory. +// We expect both LHS and RHS tiles to be directly loaded from the input memory.
+ +// CHECK-LABEL: @test_loop_pre_encoded_direct +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[LHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[RHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: amx.tile_load %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1, %[[LHS_INDICES]]#2, %[[LHS_INDICES]]#3] +// CHECK: amx.tile_load %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1, %[[RHS_INDICES]]#2, %[[RHS_INDICES]]#3] +#loc = loc(unknown) +module { + tt.func public @test_loop_pre_encoded_direct(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %c31_i32 = arith.constant 31 : i32 loc(#loc) + %c1024_i64 = arith.constant 1024 : i64 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x32xf32> loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c32_i32 = arith.constant 32 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %0 = tt.get_program_id x : i32 loc(#loc) + %1 = arith.divsi %arg4, %c32_i32 : i32 loc(#loc) + %2 = arith.divsi %0, %1 : i32 loc(#loc) + %3 = arith.remsi %0, %1 : i32 loc(#loc) + %4 = arith.muli %arg5, %c32_i32 : i32 loc(#loc) + %5 = arith.divsi %arg3, %c32_i32 : i32 loc(#loc) + %6 = arith.divsi %arg5, %c32_i32 : i32 loc(#loc) + %7 = arith.extsi %5 : i32 to i64 loc(#loc) + %8 = arith.extsi %6 : i32 to i64 loc(#loc) + %9 = arith.extsi %4 : i32 to i64 loc(#loc) + %10 = arith.extsi %arg5 : i32 to i64 loc(#loc) + %11 = tt.make_tensor_ptr %arg0, [%7, %8, %c32_i64, %c32_i64], [%9, %c32_i64, %10, %c1_i64], [%2, %c0_i32, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %12 = arith.extsi %1 : i32 to i64 loc(#loc) + %13 = tt.make_tensor_ptr %arg1, [%8, %12, %c16_i64, %c64_i64], [%c1024_i64, %9, %c64_i64, %c1_i64], [%c0_i32, %3, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %14 = arith.muli %2, %c32_i32 : i32 loc(#loc) + %15 = arith.muli %3, %c32_i32 : i32 loc(#loc) + %16 = arith.extsi %arg3 : i32 to i64 loc(#loc) + %17 = arith.extsi %arg4 : i32 to i64 loc(#loc) + %18 = tt.make_tensor_ptr %arg2, [%16, %17], [%17, %c1_i64], [%14, %15] {order = array} : > loc(#loc) + %19 = arith.addi %arg5, %c31_i32 : i32 loc(#loc) + %20 = arith.divsi %19, %c32_i32 : i32 loc(#loc) + %21:3 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %cst_0, %arg8 = %11, %arg9 = %13) -> (vector<32x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %24 = triton_cpu.extract_memref %arg8 : > -> memref> loc(#loc) + %25:4 = triton_cpu.extract_indices %arg8 : > -> index, index, index, index loc(#loc) + %26 = vector.transfer_read %24[%25#0, %25#1, %25#2, %25#3], %cst {in_bounds = [true, true]} : memref>, vector<32x32xbf16> loc(#loc) + %27 = triton_cpu.extract_memref %arg9 : > -> memref> loc(#loc) + %28:4 = triton_cpu.extract_indices %arg9 : > -> index, index, index, index loc(#loc) + %29 = vector.transfer_read %27[%28#0, %28#1, %28#2, %28#3], %cst {in_bounds = [true, true]} : memref>, vector<16x64xbf16> loc(#loc) + %res1, %res2 = vector.deinterleave %29 : vector<16x64xbf16> -> vector<16x32xbf16> 
loc(#loc) + %30 = vector.transpose %res1, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %31 = vector.transpose %res2, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %32 = vector.interleave %30, %31 : vector<32x16xbf16> -> vector<32x32xbf16> loc(#loc) + %33 = vector.transpose %32, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> loc(#loc) + %34 = triton_cpu.dot %26, %33, %arg7, inputPrecision = tf32 : vector<32x32xbf16> * vector<32x32xbf16> -> vector<32x32xf32> loc(#loc) + %35 = tt.advance %arg8, [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > loc(#loc) + %36 = tt.advance %arg9, [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > loc(#loc) + scf.yield %34, %35, %36 : vector<32x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %22 = triton_cpu.extract_memref %18 : > -> memref> loc(#loc) + %23:2 = triton_cpu.extract_indices %18 : > -> index, index loc(#loc) + vector.transfer_write %21#0, %22[%23#0, %23#1] {in_bounds = [true, true]} : vector<32x32xf32>, memref> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// A case with VNNI pre-encoded RHS that cannot be directly accessed from the input memory. +// We expect LHS to be directly loaded from the input mmemory and RHS to be loaded through +// a temporary buffer without additional encoding. + + +// CHECK-LABEL: @test_loop_pre_encoded_indirect +// CHECK: %[[RHS_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<16x64xbf16> +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[LHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[RHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK-NEXT: %[[RHS:.+]] = vector.transfer_read %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1, %[[RHS_INDICES]]#2, %[[RHS_INDICES]]#3] +// CHECK: vector.transfer_write %[[RHS]], %[[RHS_BUF]][%c0, %c0] {in_bounds = [true, true]} +// CHECK: amx.tile_load %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1, %[[LHS_INDICES]]#2, %[[LHS_INDICES]]#3] +// CHECK: amx.tile_load %[[RHS_BUF]][%c0, %c0] +#loc = loc(unknown) +module { + tt.func public @test_loop_pre_encoded_indirect(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %c31_i32 = arith.constant 31 : i32 loc(#loc) + %c1024_i64 = arith.constant 1024 : i64 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x32xf32> loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c32_i32 = arith.constant 32 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %0 = tt.get_program_id x : i32 loc(#loc) + %1 = arith.divsi %arg4, %c32_i32 : i32 loc(#loc) + %2 = arith.divsi %0, %1 : i32 loc(#loc) + %3 = arith.remsi %0, %1 : i32 loc(#loc) + %4 = arith.muli %arg5, %c32_i32 : i32 loc(#loc) + %5 = arith.divsi %arg3, %c32_i32 : i32 loc(#loc) + %6 = arith.divsi %arg5, %c32_i32 : i32 loc(#loc) + %7 = arith.extsi %5 : i32 to i64 loc(#loc) + %8 = arith.extsi %6 : i32 to i64 loc(#loc) + %9 = arith.extsi %4 : i32 to i64 loc(#loc) + %10 = arith.extsi %arg5 : i32 to i64 loc(#loc) + %11 = 
tt.make_tensor_ptr %arg0, [%7, %8, %c32_i64, %c32_i64], [%9, %c32_i64, %10, %c1_i64], [%2, %c0_i32, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %12 = arith.extsi %1 : i32 to i64 loc(#loc) + %13 = tt.make_tensor_ptr %arg1, [%8, %12, %c16_i64, %c64_i64], [%c1024_i64, %9, %c64_i64, %c1_i64], [%c0_i32, %3, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %14 = arith.muli %2, %c32_i32 : i32 loc(#loc) + %15 = arith.muli %3, %c32_i32 : i32 loc(#loc) + %16 = arith.extsi %arg3 : i32 to i64 loc(#loc) + %17 = arith.extsi %arg4 : i32 to i64 loc(#loc) + %18 = tt.make_tensor_ptr %arg2, [%16, %17], [%17, %c1_i64], [%14, %15] {order = array} : > loc(#loc) + %19 = arith.addi %arg5, %c31_i32 : i32 loc(#loc) + %20 = arith.divsi %19, %c32_i32 : i32 loc(#loc) + %21:3 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %cst_0, %arg8 = %11, %arg9 = %13) -> (vector<32x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %24 = triton_cpu.extract_memref %arg8 : > -> memref> loc(#loc) + %25:4 = triton_cpu.extract_indices %arg8 : > -> index, index, index, index loc(#loc) + %26 = vector.transfer_read %24[%25#0, %25#1, %25#2, %25#3], %cst {in_bounds = [true, true]} : memref>, vector<32x32xbf16> loc(#loc) + %27 = triton_cpu.extract_memref %arg9 : > -> memref> loc(#loc) + %28:4 = triton_cpu.extract_indices %arg9 : > -> index, index, index, index loc(#loc) + %29 = vector.transfer_read %27[%28#0, %28#1, %28#2, %28#3], %cst {in_bounds = [false, false]} : memref>, vector<16x64xbf16> loc(#loc) + %res1, %res2 = vector.deinterleave %29 : vector<16x64xbf16> -> vector<16x32xbf16> loc(#loc) + %30 = vector.transpose %res1, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %31 = vector.transpose %res2, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %32 = vector.interleave %30, %31 : vector<32x16xbf16> -> vector<32x32xbf16> loc(#loc) + %33 = vector.transpose %32, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> loc(#loc) + %34 = triton_cpu.dot %26, %33, %arg7, inputPrecision = tf32 : vector<32x32xbf16> * vector<32x32xbf16> -> vector<32x32xf32> loc(#loc) + %35 = tt.advance %arg8, [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > loc(#loc) + %36 = tt.advance %arg9, [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > loc(#loc) + scf.yield %34, %35, %36 : vector<32x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %22 = triton_cpu.extract_memref %18 : > -> memref> loc(#loc) + %23:2 = triton_cpu.extract_indices %18 : > -> index, index loc(#loc) + vector.transfer_write %21#0, %22[%23#0, %23#1] {in_bounds = [true, true]} : vector<32x32xf32>, memref> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// A case with int8 VNNI pre-encoded RHS that can be directly accessed from the input memory. +// We expect both LHS and RHS tiles to be directly loaded from the input mmemory. 
+ +// CHECK-LABEL: @test_loop_int8_pre_encoded +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[LHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref +// CHECK-NEXT: %[[RHS_INDICES:.+]]:4 = triton_cpu.extract_indices +// CHECK: amx.tile_load %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1, %[[LHS_INDICES]]#2, %[[LHS_INDICES]]#3] +// CHECK: amx.tile_load %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1, %[[RHS_INDICES]]#2, %[[RHS_INDICES]]#3] +#loc = loc(unknown) +module { + tt.func public @test_loop_int8_pre_encoded(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c0_i8 = arith.constant 0 : i8 loc(#loc) + %c31_i32 = arith.constant 31 : i32 loc(#loc) + %c1024_i64 = arith.constant 1024 : i64 loc(#loc) + %cst = arith.constant dense<0> : vector<32x32xi32> loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c128_i64 = arith.constant 128 : i64 loc(#loc) + %c8_i64 = arith.constant 8 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c32_i32 = arith.constant 32 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %0 = tt.get_program_id x : i32 loc(#loc) + %1 = arith.divsi %arg4, %c32_i32 : i32 loc(#loc) + %2 = arith.divsi %0, %1 : i32 loc(#loc) + %3 = arith.remsi %0, %1 : i32 loc(#loc) + %4 = arith.muli %arg5, %c32_i32 : i32 loc(#loc) + %5 = arith.divsi %arg3, %c32_i32 : i32 loc(#loc) + %6 = arith.divsi %arg5, %c32_i32 : i32 loc(#loc) + %7 = arith.extsi %5 : i32 to i64 loc(#loc) + %8 = arith.extsi %6 : i32 to i64 loc(#loc) + %9 = arith.extsi %4 : i32 to i64 loc(#loc) + %10 = arith.extsi %arg5 : i32 to i64 loc(#loc) + %11 = tt.make_tensor_ptr %arg0, [%7, %8, %c32_i64, %c32_i64], [%9, %c32_i64, %10, %c1_i64], [%2, %c0_i32, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %12 = arith.extsi %1 : i32 to i64 loc(#loc) + %13 = tt.make_tensor_ptr %arg1, [%8, %12, %c8_i64, %c128_i64], [%c1024_i64, %9, %c128_i64, %c1_i64], [%c0_i32, %3, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %14 = arith.muli %2, %c32_i32 : i32 loc(#loc) + %15 = arith.muli %3, %c32_i32 : i32 loc(#loc) + %16 = arith.extsi %arg3 : i32 to i64 loc(#loc) + %17 = arith.extsi %arg4 : i32 to i64 loc(#loc) + %18 = tt.make_tensor_ptr %arg2, [%16, %17], [%17, %c1_i64], [%14, %15] {order = array} : > loc(#loc) + %19 = arith.addi %arg5, %c31_i32 : i32 loc(#loc) + %20 = arith.divsi %19, %c32_i32 : i32 loc(#loc) + %21:3 = scf.for %arg6 = %c0_i32 to %20 step %c1_i32 iter_args(%arg7 = %cst, %arg8 = %11, %arg9 = %13) -> (vector<32x32xi32>, !tt.ptr>, !tt.ptr>) : i32 { + %24 = triton_cpu.extract_memref %arg8 : > -> memref> loc(#loc) + %25:4 = triton_cpu.extract_indices %arg8 : > -> index, index, index, index loc(#loc) + %26 = vector.transfer_read %24[%25#0, %25#1, %25#2, %25#3], %c0_i8 {in_bounds = [true, true]} : memref>, vector<32x32xi8> loc(#loc) + %27 = triton_cpu.extract_memref %arg9 : > -> memref> loc(#loc) + %28:4 = triton_cpu.extract_indices %arg9 : > -> index, index, index, index loc(#loc) + %30 = vector.transfer_read %27[%28#0, %28#1, %28#2, %28#3], %c0_i8 {in_bounds = [true, true]} : memref>, vector<8x128xi8> loc(#loc) + %res1, %res2 = vector.deinterleave %30 : vector<8x128xi8> -> vector<8x64xi8> loc(#loc) + %31 = vector.transpose 
%res1, [1, 0] : vector<8x64xi8> to vector<64x8xi8> loc(#loc) + %32 = vector.transpose %res2, [1, 0] : vector<8x64xi8> to vector<64x8xi8> loc(#loc) + %33 = vector.interleave %31, %32 : vector<64x8xi8> -> vector<64x16xi8> loc(#loc) + %34 = vector.transpose %33, [1, 0] : vector<64x16xi8> to vector<16x64xi8> loc(#loc) + %res1_0, %res2_1 = vector.deinterleave %34 : vector<16x64xi8> -> vector<16x32xi8> loc(#loc) + %35 = vector.transpose %res1_0, [1, 0] : vector<16x32xi8> to vector<32x16xi8> loc(#loc) + %36 = vector.transpose %res2_1, [1, 0] : vector<16x32xi8> to vector<32x16xi8> loc(#loc) + %37 = vector.interleave %35, %36 : vector<32x16xi8> -> vector<32x32xi8> loc(#loc) + %38 = vector.transpose %37, [1, 0] : vector<32x32xi8> to vector<32x32xi8> loc(#loc) + %39 = triton_cpu.dot %26, %38, %arg7, inputPrecision = tf32 : vector<32x32xi8> * vector<32x32xi8> -> vector<32x32xi32> loc(#loc) + %40 = tt.advance %arg8, [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > loc(#loc) + %41 = tt.advance %arg9, [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > loc(#loc) + scf.yield %39, %40, %41 : vector<32x32xi32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %22 = triton_cpu.extract_memref %18 : > -> memref> loc(#loc) + %23:2 = triton_cpu.extract_indices %18 : > -> index, index loc(#loc) + vector.transfer_write %21#0, %22[%23#0, %23#1] {in_bounds = [true, true]} : vector<32x32xi32>, memref> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) diff --git a/third_party/cpu/language/cpu/__init__.py b/third_party/cpu/language/cpu/__init__.py index e69de29bb2d1..d0618d5cd3e9 100644 --- a/third_party/cpu/language/cpu/__init__.py +++ b/third_party/cpu/language/cpu/__init__.py @@ -0,0 +1,3 @@ +from .utils import vnni_decode + +__all__ = ["vnni_decode"] diff --git a/third_party/cpu/language/cpu/utils.py b/third_party/cpu/language/cpu/utils.py new file mode 100644 index 000000000000..82538971a971 --- /dev/null +++ b/third_party/cpu/language/cpu/utils.py @@ -0,0 +1,22 @@ +from triton import jit +import triton.language as tl +from triton.language.core import builtin + + +@jit +def _vnni_decode(arg0): + tl.static_assert(len(arg0.shape) == 2) + tmp = arg0.reshape((arg0.shape[0], arg0.shape[1] // 2, 2)) + tmp1, tmp2 = tl.split(tmp) + return tl.join(tmp1.T, tmp2.T).reshape((arg0.shape[1] // 2, arg0.shape[0] * 2)).T + + +@builtin +def vnni_decode(arg0, _builder=None, _generator=None): + bitwidth = arg0.dtype.primitive_bitwidth + if bitwidth > 16: + raise ValueError("Expected 8-bit or 16-bit values for vnni_decode") + decoded = _generator.call_JitFunction(_vnni_decode, (arg0, ), kwargs={}) + if bitwidth == 8: + decoded = _generator.call_JitFunction(_vnni_decode, (decoded, ), kwargs={}) + return decoded diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp index 4ad5de863fb4..8fc432b9734e 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp @@ -56,15 +56,80 @@ Value getInitAccValue(Value val) { return forOp.getInitArgs()[initValIdx]; } -MemBuffer findInputBuffer(Value val, bool allowTransposed) { +namespace { + +// Check if val is a result of transpose operation. If it is, then return +// a source of that transpose operation. Otherwise, return nullptr. 
+Value getTransposedSrc(Value val) { + auto transposeOp = val.getDefiningOp(); + if (transposeOp) + return transposeOp.getVector(); + return nullptr; +} + +// We are looking for the following sequence: +// %tmp1, %tmp2 = vector.deinterleave %src +// %tmp3 = vector.transpose %tmp1, [1, 0] +// %tmp4 = vector.transpose %tmp2, [1, 0] +// %tmp5 = vector.interleave %tmp3, %tmp4 +// %val = vector.transpose %tmp5, [1, 0] +// and return %src if pattern matching succeeds. +Value getVnniSrcImpl(Value val) { + auto transposedVal = getTransposedSrc(val); + if (!transposedVal) + return nullptr; + + auto interleave = transposedVal.getDefiningOp(); + if (!interleave) + return nullptr; + + auto tmp1 = getTransposedSrc(interleave.getLhs()); + auto tmp2 = getTransposedSrc(interleave.getRhs()); + if (!tmp1 || !tmp2) + return nullptr; + + auto deinterleave1 = tmp1.getDefiningOp(); + auto deinterleave2 = tmp2.getDefiningOp(); + if (!deinterleave1 || deinterleave1 != deinterleave2 || + deinterleave1.getResult(0) != tmp1 || deinterleave2.getResult(1) != tmp2) + return nullptr; + + return deinterleave1.getSource(); +} + +} // namespace + +Value getVnniSrc(Value val) { + Type elemTy = getElementTypeOrSelf(val.getType()); + + // VNNI encoding is used for 8-bit and 16-bit values only. + if (elemTy.getIntOrFloatBitWidth() > 16) + return nullptr; + + // For 16-bit values VNNI encoding is a single interleave of + // subsequent rows. For 8-bit values, it's applied twice. + Value encoded = getVnniSrcImpl(val); + if (encoded && elemTy.getIntOrFloatBitWidth() == 8) + encoded = getVnniSrcImpl(encoded); + + return encoded; +} + +MemBuffer findInputBuffer(Value val, bool allowTransposed, bool allowVnni) { MemBuffer buf; if (allowTransposed) { - auto transposeOp = val.getDefiningOp(); - if (transposeOp) { - val = transposeOp.getVector(); + auto transposed = getTransposedSrc(val); + if (transposed) { + val = transposed; buf.transposed = true; } + } else if (allowVnni) { + auto vnniVal = getVnniSrc(val); + if (vnniVal) { + val = vnniVal; + buf.vnni = true; + } } auto valLoad = val.getDefiningOp(); diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h index e26529d91882..2760ebd14fbb 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h @@ -24,6 +24,9 @@ struct MemBuffer { SmallVector step; // True if buffer holds transposed value. bool transposed = false; + // True if buffer holds value in VNNI (interleaved to groups of 32bit) + // encoding. + bool vnni = false; bool empty() const { return !memRef; } }; @@ -48,10 +51,17 @@ template bool hasMaskOrBoundsCheck(T op) { return hasBoundsCheck || mask; } -// Search for a buffer holding required value. If allowTransposed is true, -// then buffer is allowed to hold both transposed and not transposed value. +// Search for a buffer holding required value. +// +// If allowTransposed is true, then buffer is allowed to hold both transposed +// and not transposed value. +// +// If allowVnni is true, then buffer is allowed to hold value in both original and +// VNNI-encoded form. This flag is ignored if allowTransposed is true. +// +// Return empty buffer if no memory holding value was found.
-MemBuffer findInputBuffer(Value val, bool allowTransposed = false); +MemBuffer findInputBuffer(Value val, bool allowTransposed = false, + bool allowVnni = false); // Cast vector to a specified element type using ext or trunc // operations. Return the original value if it already matches @@ -67,6 +77,10 @@ MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, Value shiftIndex(Location loc, Value index, int64_t offs, PatternRewriter &rewriter); +// Check if val is a result of a sequence that performs VNNI decoding. +// If it is, then return the original encoded value. Otherwise, return nullptr. +Value getVnniSrc(Value val); + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index 1b6dd9269ac1..11ce852e7570 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -402,9 +402,9 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, LDBG("Preparing buffer (interleave=" << interleave << ") for a vector: " << val); auto vecTy = cast(val.getType()); - MemBuffer inputBuf = findInputBuffer(val); + MemBuffer inputBuf = findInputBuffer(val, false, interleave); if (!inputBuf.empty()) { - if (interleave) { + if (interleave && !inputBuf.vnni) { LDBG(" Copying from the original memref with interleave: " << inputBuf.memRef); auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy), @@ -426,7 +426,12 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); if (interleave) { - interleaveAndStore(loc, val, buf.memRef, rewriter); + auto interleavedVal = getVnniSrc(val); + if (interleavedVal) { + LDBG(" Using pre-encoding value: " << interleavedVal); + op_write(interleavedVal, buf.memRef, buf.indices); + } else + interleaveAndStore(loc, val, buf.memRef, rewriter); } else { op_write(val, buf.memRef, buf.indices); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp index a39a93e42446..cc8ccfeb5374 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElemManipOps.cpp @@ -177,31 +177,20 @@ struct SplitOpConversion : public OpConversionPattern { auto src = rewriter.getRemappedValue(op.getSrc()); auto srcTy = cast(src.getType()); auto resTy = getTypeConverter()->convertType(op.getType(0)); + assert(srcTy.getShape().back() == 2); SmallVector results; if (srcTy.getRank() == 1) { results.push_back(rewriter.create(loc, src, 0)); results.push_back(rewriter.create(loc, src, 1)); + rewriter.replaceOp(op, results); } else { - SmallVector tmpShape({srcTy.getNumElements()}); + SmallVector tmpShape(srcTy.getShape().drop_back()); + tmpShape.back() *= 2; auto tmp = rewriter.create( loc, VectorType::get(tmpShape, srcTy.getElementType()), src); - - SmallVector evenIndices; - SmallVector oddIndices; - for (int64_t i = 0; i < srcTy.getNumElements(); i += 2) { - evenIndices.push_back(i); - oddIndices.push_back(i + 1); - } - - Value res1 = - rewriter.create(loc, tmp, tmp, evenIndices); - Value res2 = - rewriter.create(loc, tmp, tmp, oddIndices); - results.push_back(rewriter.create(loc, resTy, res1)); - results.push_back(rewriter.create(loc, resTy, res2)); + 
rewriter.replaceOpWithNewOp(op, tmp); } - rewriter.replaceOp(op, results); return success(); } }; From b8120678a82f158d02727da2b9217eb765a77449 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Thu, 23 Jan 2025 23:27:30 -0600 Subject: [PATCH 160/165] Update default target selection logic (#212) First, try to look up the target in the given module. If that fails, use the default target and set it in the module. Rationale: Issue #207 --- python/src/llvm.cc | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 5977da3e9221..9040db7f402d 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -59,13 +59,17 @@ std::unique_ptr createTargetMachine(llvm::Module *module, std::string proc, bool enable_fp_fusion, const std::string &features, bool enable_fast_math = false) { - auto triple = getDefaultTargerOrProcessTriple(); - module->setTargetTriple(triple); std::string error; auto target = llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error); if (!target) { - throw std::runtime_error("target lookup error: " + error); + // Try to get the default target triple. + auto triple = getDefaultTargerOrProcessTriple(); + target = llvm::TargetRegistry::lookupTarget(triple, error); + if (!target) { + throw std::runtime_error("target lookup error: " + error); + } + module->setTargetTriple(triple); } llvm::TargetOptions opt; bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT"); From b0015d3335828b3c4397fc50077804b3cc9b8287 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Tue, 18 Feb 2025 18:11:59 +0100 Subject: [PATCH 161/165] [OneDNN] Ukernel Backend interface (#197) This PR introduces a ukernels API that allows the use of third-party libraries such as OneDNN. These libraries provide efficient implementations of brgemm/transform and some other ops, so triton_cpu.dot ops are replaced with calls to library kernels where possible.
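As an illustration (not part of this patch), below is a minimal Python sketch of how the new path is exercised, modeled on the cpu-blocked-matmul.py changes further down; it assumes a triton-cpu build where the `ukernels` kernel option introduced by this patch accepts the value "OneDNN", and the kernel, shapes, and block size are only examples.

    # Minimal usage sketch (illustrative only): run a Triton matmul on the CPU
    # backend and ask the compiler to lower eligible tl.dot ops to OneDNN brgemm
    # ukernels via the `ukernels` option added by this patch.
    import torch
    import triton
    import triton.language as tl

    triton.runtime.driver.set_active_to_cpu()

    @triton.jit
    def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK: tl.constexpr):
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        offs_m = pid_m * BLOCK + tl.arange(0, BLOCK)
        offs_n = pid_n * BLOCK + tl.arange(0, BLOCK)
        acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
        for k in range(0, K, BLOCK):
            offs_k = k + tl.arange(0, BLOCK)
            a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
            b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
            acc += tl.dot(a, b, out_dtype=tl.float32)
        tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)

    M, N, K = 256, 256, 256
    a = torch.randn((M, K), device='cpu', dtype=torch.bfloat16)
    b = torch.randn((K, N), device='cpu', dtype=torch.bfloat16)
    c = torch.empty((M, N), device='cpu', dtype=torch.float32)
    grid = (M // 64, N // 64)
    # Pass ukernels=None (or omit the option) to keep the default vector lowering.
    matmul_kernel[grid](a, b, c, M, N, K, BLOCK=64, ukernels="OneDNN")

The tutorial changes below wire the same option through the matmul() wrapper and the benchmark providers (the "-ukOneDNN" suffix).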
Signed-off-by: Dmitrii Makarenko Co-authored-by: Ilya Enkovich --- .../Dialect/TritonCPU/IR/TritonCPUOps.td | 40 ++ python/test/unit/language/test_core.py | 3 +- python/triton/testing.py | 9 +- python/tutorials/cpu-blocked-matmul.py | 29 +- test/TritonCPU/dot-to-onednn.mlir | 237 +++++++++ third_party/cpu/CMakeLists.txt | 22 +- third_party/cpu/backend/compiler.py | 29 ++ .../cpu/include/TritonCPUToLLVM/Passes.h | 1 + .../cpu/include/TritonCPUToLLVM/Passes.td | 11 + .../include/TritonCPUTransforms/OptCommon.h | 4 +- .../cpu/include/TritonCPUTransforms/Passes.h | 7 + .../cpu/include/TritonCPUTransforms/Passes.td | 23 + .../cpu/lib/TritonCPUToLLVM/CMakeLists.txt | 1 + .../UkernelOpsToOneDNNLLVM.cpp | 210 ++++++++ .../lib/TritonCPUTransforms/CMakeLists.txt | 1 + .../ConvertDotOp/ConvertDotCommon.cpp | 17 +- .../ConvertDotOp/ConvertDotCommon.h | 11 +- .../ConvertDotOp/ConvertDotOpToUkernelOps.cpp | 480 ++++++++++++++++++ .../ConvertDotOp/ConvertDotToAMX.cpp | 22 +- .../ConvertDotOp/ConvertDotToFMA.cpp | 9 - third_party/cpu/runtime/runtime_onednn.cpp | 155 ++++++ third_party/cpu/triton_cpu.cc | 31 ++ 22 files changed, 1304 insertions(+), 48 deletions(-) create mode 100644 test/TritonCPU/dot-to-onednn.mlir create mode 100644 third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp create mode 100644 third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotOpToUkernelOps.cpp create mode 100644 third_party/cpu/runtime/runtime_onednn.cpp diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td index b58fd9320354..fdb39e42529a 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUOps.td @@ -191,4 +191,44 @@ def TTC_DotOp : TTC_Op<"dot", [Pure, }]; } +def TTC_BrgemmCreate : TTC_Op<"brgemm_create", [NoMemoryEffect]> { + let summary = "Create ukernel handles"; + + let description = [{Creates a ukernel handle that can be used to replace ops with dot-like semantics}]; + + // M, N, K_k, batch_size, lda, ldb, ldc, dtypeA, dtypeB, dtypeC + let arguments = (ins + AnyTypeOf<[AnyInteger, Index]>:$M, + AnyTypeOf<[AnyInteger, Index]>:$N, + AnyTypeOf<[AnyInteger, Index]>:$K_k, + AnyTypeOf<[AnyInteger, Index]>:$batch_size, + AnyTypeOf<[AnyInteger, Index]>:$lda, + AnyTypeOf<[AnyInteger, Index]>:$ldb, + AnyTypeOf<[AnyInteger, Index]>:$ldc, + // TODO: Maybe use properties + TypeAttr:$dtypeA, + TypeAttr:$dtypeB, + TypeAttr:$dtypeC + ); + + let results = (outs Index:$result); +} + +def TTC_BrgemmExecute : TTC_Op<"brgemm_execute", + [MemoryEffects<[MemRead, + MemWrite]>]> { + let summary = "Call a ukernel with an existing handle for the passed operands"; + + let arguments = (ins + Index:$brgemm_kernel_hash, + Arg:$A_ptr, + Arg:$B_ptr, + Arg:$C_ptr, + AnyTypeOf<[AnyInteger, Index]>:$stepA, + AnyTypeOf<[AnyInteger, Index]>:$stepB, + AnyTypeOf<[AnyInteger, Index]>:$blockedBsize, + AnyTypeOf<[AnyInteger, Index]>:$numBatches + ); +} + #endif diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index e4e0b90819f1..afce8d7f65a3 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3528,7 +3528,8 @@ def get_test_dot_small_mn_fma_cases(): def get_test_dot_double_rate_cases(): if not is_hip_cdna(): return [] - return [(32, 32, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), + return [(64, 64, 64, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None), + (32, 32, 16, 4, False, False, 'None',
'ieee', 'float16', 'float32', 1, None), (32, 32, 16, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None), (16, 16, 32, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), (16, 16, 32, 4, False, False, 'None', 'ieee', 'bfloat16', 'float32', 1, None)] diff --git a/python/triton/testing.py b/python/triton/testing.py index 5dadabaa1001..54d59de6066c 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -357,12 +357,17 @@ def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: b y_min, y_max = df[y + '-min'], df[y + '-max'] col = bench.styles[i][0] if bench.styles else None sty = bench.styles[i][1] if bench.styles else None - ax.plot(df[first_x], df[y], label=y, color=col, ls=sty) + ax.plot(df[first_x], df[y], color=col, ls=sty) + ax.annotate(y, xy=(df[first_x], df[y]), xytext=(1.02 * df[first_x], df[y]), color=col) if not y_min.isnull().all() and not y_max.isnull().all(): y_min = y_min.astype(float) y_max = y_max.astype(float) ax.fill_between(df[first_x], y_min, y_max, alpha=0.15, color=col) - ax.legend() + # ax.legend() + ax.minorticks_on() + ax.grid(which='minor', alpha=0.2) + ax.grid(which='major', alpha=0.5) + ax.set_xlabel(bench.xlabel or first_x) ax.set_ylabel(bench.ylabel) # ax.set_title(bench.plot_name) diff --git a/python/tutorials/cpu-blocked-matmul.py b/python/tutorials/cpu-blocked-matmul.py index e8f274d6c552..e04295863933 100644 --- a/python/tutorials/cpu-blocked-matmul.py +++ b/python/tutorials/cpu-blocked-matmul.py @@ -216,7 +216,7 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOC def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, bb: torch.Tensor, M, N, K, PREPACKED, - BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, PACKED_B, num_threads=0): + BLOCKED_A, TRANSPOSED_BLOCK_A, BLOCKED_B, TRANSPOSED_B, PACKED_B, num_threads=0, ukernels=None): #TODO: Currently masked load is not supported yet. assert (M % BLOCK_SIZE_M == 0) and (N % BLOCK_SIZE_N == 0) and ( K % BLOCK_SIZE_K == 0), "Masking currently not supported, Matrix dimensions must be multiples of block size" @@ -241,7 +241,7 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, GROUP_SIZE_M=GROUP_SIZE_M, # BLOCKED_A=BLOCKED_A, TRANSPOSED_BLOCK_A=TRANSPOSED_BLOCK_A, # BLOCKED_B=BLOCKED_B, TRANSPOSED_B=TRANSPOSED_B, PACKED_B=PACKED_B, # - OUT_DTYPE=tl.float32 if a.dtype.is_floating_point else tl.int32, num_threads=num_threads) + OUT_DTYPE=tl.float32 if a.dtype.is_floating_point else tl.int32, num_threads=num_threads, ukernels=ukernels) return c @@ -291,9 +291,10 @@ def matmul(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, ab: torch.Tensor, # but feel free to arrange this script as you wish to benchmark any other matrix shape. 
-def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype): +def encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype, + ukernel): assert dtype == 'float32' or dtype == 'bfloat16' or dtype == 'float16' or dtype == 'int8' - return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-pb' if packed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-{dtype}" + return f"triton-cpu{'-ba' if blocked_a else ''}{'-ta' if transposed_a else ''}{'-bb' if blocked_b else ''}{'-tb' if transposed_b else ''}{'-pb' if packed_b else ''}{'-prepack' if prepack else ''}{'-st' if single_thread else ''}-uk{ukernel}-{dtype}" def encode_torch_provider(single_thread, dtype): @@ -310,13 +311,20 @@ def decode_provider(provider): dtype = torch.float32 elif '-int8' in provider: dtype = torch.int8 + + ukernel = None + if '-ukNone' in provider: + ukernel = None + if '-ukOneDNN' in provider: + ukernel = "OneDNN" + if 'triton-cpu' in provider: backend = 'triton-cpu' elif 'torch-cpu-native' in provider: backend = 'torch-cpu-native' elif 'torch-cpu-compile' in provider: backend = 'torch-cpu-compile' - return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-pb' in provider, '-prepack' in provider, '-st' in provider, dtype + return backend, '-ba' in provider, '-ta' in provider, '-bb' in provider, '-tb' in provider, '-pb' in provider, '-prepack' in provider, '-st' in provider, ukernel, dtype BLOCK_TRANSPOSE_A_OPTS = [(False, False)] @@ -324,13 +332,16 @@ def decode_provider(provider): PREPACK_OPTS = [False, True] SINGLE_THREAD_OPTS = [False] DTYPE_OPTS = [DTYPE] +UKERNEL_OPTS = [None, "OneDNN"] LINE_VALS = [ - encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype) + encode_triton_provider(blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype, + ukernel) for single_thread in SINGLE_THREAD_OPTS for blocked_a, transposed_a in BLOCK_TRANSPOSE_A_OPTS for blocked_b, transposed_b, packed_b in BLOCK_TRANSPOSE_PACK_B_OPTS for prepack in PREPACK_OPTS for dtype in DTYPE_OPTS + for ukernel in UKERNEL_OPTS if (blocked_a or blocked_b or not prepack) and (not packed_b or dtype != "float32") ] + [encode_torch_provider(single_thread, dtype) for dtype in DTYPE_OPTS for single_thread in SINGLE_THREAD_OPTS] LINE_NAMES = LINE_VALS @@ -356,7 +367,7 @@ def decode_provider(provider): def benchmark(M, N, K, provider): device = 'cpu' if 'cpu' in provider else 'cuda' - backend, blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, dtype = decode_provider( + backend, blocked_a, transposed_a, blocked_b, transposed_b, packed_b, prepack, single_thread, ukernel, dtype = decode_provider( provider) if dtype.is_floating_point: a = torch.randn((M, K), device=device, dtype=dtype) @@ -400,8 +411,8 @@ def benchmark(M, N, K, provider): elif backend == 'triton-cpu': ms, min_ms, max_ms = triton.testing.do_bench( lambda: matmul(a, b, c, a_tmp, b_tmp, M, N, K, prepack, blocked_a, transposed_a, blocked_b, transposed_b, - packed_b, num_threads=int(single_thread)), quantiles=quantiles, measure_time_with_hooks=True, - rep=1000) + packed_b, num_threads=int(single_thread), ukernels=ukernel), quantiles=quantiles, + measure_time_with_hooks=True, rep=1000) perf = lambda ms: 2 * M * N * K * 1e-9 / (ms * 
1e-3) return perf(ms), perf(max_ms), perf(min_ms) diff --git a/test/TritonCPU/dot-to-onednn.mlir b/test/TritonCPU/dot-to-onednn.mlir new file mode 100644 index 000000000000..a9dc51008edb --- /dev/null +++ b/test/TritonCPU/dot-to-onednn.mlir @@ -0,0 +1,237 @@ +// RUN: triton-opt %s -split-input-file -triton-cpu-convert-dot-to-ukernels="ukernels=oneDNN" -cse | FileCheck %s + +// Replacement of a triton_cpu.dot operation with triton_cpu.brgemm_execute + +// CHECK-LABEL: @test_two_tiles_four_mulf +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref %0 : > -> memref<16x64xbf16, strided<[64, 1]>> +// CHECK-NEXT: %[[LHS_INDICES:.+]]:2 = triton_cpu.extract_indices %0 : > -> index, index +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref %1 : > -> memref<64x32xbf16, strided<[32, 1]>> +// CHECK-NEXT: %[[RHS_INDICES:.+]]:2 = triton_cpu.extract_indices %1 : > -> index, index +// CHECK: %[[ACC_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<16x32xf32> +// CHECK: %[[LHS_SUBVIEW:.+]] = memref.subview %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1] [16, 64] [1, 1] : memref<16x64xbf16, strided<[64, 1]>> to memref<16x64xbf16, strided<[64, 1], offset: ?>> +// CHECK: %[[RHS_SUBVIEW:.+]] = memref.subview %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1] [64, 32] [1, 1] : memref<64x32xbf16, strided<[32, 1]>> to memref<64x32xbf16, strided<[32, 1], offset: ?>> +// CHECK: %[[NONE1:.+]], %[[NONE2:.+]], %[[NONE3:.+]]:2, %[[LHS_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[LHS_SUBVIEW]] : memref<16x64xbf16, strided<[64, 1], offset: ?>> -> memref, index, index, index, index, index +// CHECK: %[[NONE4:.+]], %[[NONE5:.+]], %[[NONE6:.+]]:2, %[[RHS_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[RHS_SUBVIEW]] : memref<64x32xbf16, strided<[32, 1], offset: ?>> -> memref, index, index, index, index, index +// CHECK: %[[NONE7:.+]], %[[NONE8:.+]], %[[NONE0:.+]]:2, %[[ACC_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[ACC_BUF]] : memref<16x32xf32> -> memref, index, index, index, index, index +// CHECK: %[[ONEDNN_HANDLE:.+]] = "triton_cpu.brgemm_create"(%c16{{.*}}, %c32{{.*}}, %c64{{.*}}, %c1{{.*}}, %[[LHS_STRIDES]]#0, %[[RHS_STRIDES]]#0, %[[ACC_STRIDES]]#0) <{dtypeA = vector<16x64xbf16>, dtypeB = vector<64x32xbf16>, dtypeC = f32}> : (i64, i64, i64, index, index, index, index) -> index +// CHECK: %[[BTW:.+]] = arith.constant 2 : i64 +// CHECK: %[[BLOCK:.+]] = arith.muli %c32{{.*}}, %c64{{.*}} : i64 +// CHECK: %[[BLOCKEDB_SIZE:.+]] = arith.muli %[[BLOCK]], %[[BTW]] : i64 +// CHECK: "triton_cpu.brgemm_execute"(%[[ONEDNN_HANDLE]], %[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]], %[[ACC_BUF]], %c0, %c0, %[[BLOCKEDB_SIZE]], %c1) : (index, memref<16x64xbf16, strided<[64, 1], offset: ?>>, memref<64x32xbf16, strided<[32, 1], offset: ?>>, memref<16x32xf32>, index, index, i64, index) -> () +// CHECK: %[[RES:.+]] = vector.transfer_read %[[ACC_BUF]][%c0, %c0], %cst_10 : memref<16x32xf32>, vector<16x32xf32> + +#loc = loc(unknown) +module { + tt.func public @test_two_tiles_four_mulf(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x32xf32> loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 
= arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c16_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c64_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c16_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3 = triton_cpu.extract_memref %0 : > -> memref<16x64xbf16, strided<[64, 1]>> loc(#loc) + %4:2 = triton_cpu.extract_indices %0 : > -> index, index loc(#loc) + %5 = vector.transfer_read %3[%4#0, %4#1], %cst {in_bounds = [true, true]} : memref<16x64xbf16, strided<[64, 1]>>, vector<16x64xbf16> loc(#loc) + %6 = triton_cpu.extract_memref %1 : > -> memref<64x32xbf16, strided<[32, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %1 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x32xbf16, strided<[32, 1]>>, vector<64x32xbf16> loc(#loc) + %9 = triton_cpu.dot %5, %8, %cst_0, inputPrecision = ieee : vector<16x64xbf16> * vector<64x32xbf16> -> vector<16x32xf32> loc(#loc) + %10 = triton_cpu.extract_memref %2 : > -> memref<16x32xf32, strided<[32, 1]>> loc(#loc) + %11:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %9, %10[%11#0, %11#1] {in_bounds = [true, true]} : vector<16x32xf32>, memref<16x32xf32, strided<[32, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + +// ----- + +// More complicated case with a loop that can be replaced with single triton_cpu.brgemm_execute. + +// CHECK-LABEL: @test_loop_acc_two_blocks +// CHECK: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref %0 : > -> memref<64x128xbf16, strided<[128, 1]>> +// CHECK-NEXT: %[[LHS_INDICES:.+]]:2 = triton_cpu.extract_indices %0 : > -> index, index +// CHECK: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref %1 : > -> memref<128x32xbf16, strided<[32, 1]>> +// CHECK-NEXT: %[[RHS_INDICES:.+]]:2 = triton_cpu.extract_indices %1 : > -> index, index +// CHECK: %[[LOOP_LENGTH:.+]] = arith.subi %c2{{.*}}, %c0{{.*}} : i32 +// CHECK: %[[NUM_BATCHES_INT:.+]] = arith.divui %[[LOOP_LENGTH]], %c1{{.*}} : i32 +// CHECK: %[[NUM_BATCHES:.+]] = arith.index_cast %[[NUM_BATCHES_INT]] : i32 to index +// CHECK: %[[ACC_BUF:.+]] = memref.alloca() {alignment = 64 : i64} : memref<64x32xf32> +// CHECK: %[[LHS_SUBVIEW:.+]] = memref.subview %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1] [64, 64] [1, 1] : memref<64x128xbf16, strided<[128, 1]>> to memref<64x64xbf16, strided<[128, 1], offset: ?>> +// CHECK: %[[RHS_SUBVIEW:.+]] = memref.subview %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1] [64, 32] [1, 1] : memref<128x32xbf16, strided<[32, 1]>> to memref<64x32xbf16, strided<[32, 1], offset: ?>> +// CHECK: %[[NONE1:.+]], %[[NONE2:.+]], %[[NONE3:.+]]:2, %[[LHS_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[LHS_SUBVIEW]] : memref<64x64xbf16, strided<[128, 1], offset: ?>> -> memref, index, index, index, index, index +// CHECK: %[[NONE4:.+]], %[[NONE5:.+]], %[[NONE6:.+]]:2, %[[RHS_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[RHS_SUBVIEW]] : memref<64x32xbf16, strided<[32, 1], offset: ?>> -> memref, index, index, index, index, index +// CHECK: %[[NONE7:.+]], %[[NONE8:.+]], %[[NONE0:.+]]:2, %[[ACC_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[ACC_BUF]] : memref<64x32xf32> -> memref, index, index, index, index, index +// CHECK: %[[ONEDNN_HANDLE:.+]] = "triton_cpu.brgemm_create"(%c64{{.*}}, %c32{{.*}}, %c64{{.*}}, 
%[[NUM_BATCHES]], %[[LHS_STRIDES]]#0, %[[RHS_STRIDES]]#0, %[[ACC_STRIDES]]#0) <{dtypeA = vector<64x64xbf16>, dtypeB = vector<64x32xbf16>, dtypeC = f32}> : (i64, i64, i64, index, index, index, index) -> index +// CHECK: %[[BTW:.+]] = arith.constant 2 : i64 +// CHECK: %[[BLOCK:.+]] = arith.muli %c32{{.*}}, %c64{{.*}} : i64 +// CHECK: %[[BLOCKEDB_SIZE:.+]] = arith.muli %[[BLOCK]], %[[BTW]] : i64 +// CHECK: %[[STEP:.+]] = arith.index_cast %c0{{.*}} : i32 to index +// CHECK: %[[OFFSET:.+]] = arith.muli %[[STEP]], %[[LHS_STRIDES]]#0 : index +// CHECK: %[[SZ:.+]] = arith.addi %c0, %[[OFFSET]] : index +// CHECK: %[[STEP1:.+]] = arith.index_cast %c64{{.*}} : i32 to index +// CHECK: %[[OFFSET1:.+]] = arith.muli %[[STEP1]], %[[LHS_STRIDES]]#1 : index +// CHECK: %[[LHS_STEP_ELEM:.+]] = arith.addi %[[SZ]], %[[OFFSET1]] : index +// CHECK: %[[CONST:.+]] = arith.constant 2 : index +// CHECK: %[[LHS_STEP:.+]] = arith.muli %[[LHS_STEP_ELEM]], %[[CONST]] : index +// CHECK: %[[R_OFFSET:.+]] = arith.muli %[[STEP1]], %[[RHS_STRIDES]]#0 : index +// CHECK: %[[SZ1:.+]] = arith.addi %c0, %[[R_OFFSET]] : index +// CHECK: %[[R_OFFSET1:.+]] = arith.muli %[[STEP]], %[[RHS_STRIDES]]#1 : index +// CHECK: %[[RHS_STEP_ELEM:.+]] = arith.addi %[[SZ1]], %[[R_OFFSET1]] : index +// CHECK: %[[RHS_STEP:.+]] = arith.muli %[[RHS_STEP_ELEM]], %[[CONST]] : index +// CHECK: "triton_cpu.brgemm_execute"(%[[ONEDNN_HANDLE]], %[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]], %[[ACC_BUF]], %[[LHS_STEP]], %[[RHS_STEP]], %[[BLOCKEDB_SIZE]], %[[NUM_BATCHES]]) : (index, memref<64x64xbf16, strided<[128, 1], offset: ?>>, memref<64x32xbf16, strided<[32, 1], offset: ?>>, memref<64x32xf32>, index, index, i64, index) -> () +// CHECK: %[[RES:.+]] = vector.transfer_read %[[ACC_BUF]][%c0, %c0], %cst_10 {in_bounds = [true, true]} : memref<64x32xf32>, vector<64x32xf32> + +#loc = loc(unknown) +module { + tt.func public @test_loop_acc_two_blocks(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %c2_i32 = arith.constant 2 : i32 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + %c64_i32 = arith.constant 64 : i32 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<64x32xf32> loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c128_i64 = arith.constant 128 : i64 loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %0 = tt.make_tensor_ptr %arg0, [%c64_i64, %c128_i64], [%c128_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %1 = tt.make_tensor_ptr %arg1, [%c128_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %2 = tt.make_tensor_ptr %arg2, [%c64_i64, %c32_i64], [%c32_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > loc(#loc) + %3:3 = scf.for %arg3 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg4 = %cst_0, %arg5 = %0, %arg6 = %1) -> (vector<64x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %6 = triton_cpu.extract_memref %arg5 : > -> memref<64x128xbf16, strided<[128, 1]>> loc(#loc) + %7:2 = triton_cpu.extract_indices %arg5 : > -> index, index loc(#loc) + %8 = vector.transfer_read %6[%7#0, %7#1], %cst {in_bounds = [true, true]} : memref<64x128xbf16, strided<[128, 1]>>, vector<64x64xbf16> loc(#loc) + %9 = triton_cpu.extract_memref %arg6 : > -> memref<128x32xbf16, strided<[32, 1]>> loc(#loc) + %10:2 
= triton_cpu.extract_indices %arg6 : > -> index, index loc(#loc) + %11 = vector.transfer_read %9[%10#0, %10#1], %cst {in_bounds = [true, true]} : memref<128x32xbf16, strided<[32, 1]>>, vector<64x32xbf16> loc(#loc) + %12 = triton_cpu.dot %8, %11, %arg4, inputPrecision = ieee : vector<64x64xbf16> * vector<64x32xbf16> -> vector<64x32xf32> loc(#loc) + %13 = tt.advance %arg5, [%c0_i32, %c64_i32] : > loc(#loc) + %14 = tt.advance %arg6, [%c64_i32, %c0_i32] : > loc(#loc) + scf.yield %12, %13, %14 : vector<64x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %4 = triton_cpu.extract_memref %2 : > -> memref<64x32xf32, strided<[32, 1]>> loc(#loc) + %5:2 = triton_cpu.extract_indices %2 : > -> index, index loc(#loc) + vector.transfer_write %3#0, %4[%5#0, %5#1] {in_bounds = [true, true]} : vector<64x32xf32>, memref<64x32xf32, strided<[32, 1]>> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) + + +// ----- + +// Case with a loop, that cannot be replaced as a whole and brgemm call should be +// injected in it's body. + +// CHECK-LABEL: @test_loop_with_transpose +// CHECK: %[[LHS_TPTR:.*]] = tt.make_tensor_ptr %arg0, [%17, %18, %c32_i64, %c32_i64], [%19, %c1024_i64, %c32_i64, %c1_i64], [%11, %c0_i32, %c0_i32, %c0_i32] {order = array} : > +// CHECK: %[[RHS_TPTR:.*]] = tt.make_tensor_ptr %arg1, [%18, %22, %c16_i64, %c64_i64], [%c1024_i64, %19, %c64_i64, %c1_i64], [%c0_i32, %13, %c0_i32, %c0_i32] {order = array} : > +// CHECK: %[[RES_TPTR:.*]] = tt.make_tensor_ptr %arg2, [%26, %27], [%27, %c1_i64], [%24, %25] {order = array} : > +// CHECK: %[[LHS_ALLOCA:.*]] = memref.alloca() {alignment = 64 : i64} : memref<32x32xbf16> +// CHECK: %[[RHS_ALLOCA:.*]] = memref.alloca() {alignment = 64 : i64} : memref<32x32xbf16> +// CHECK: %[[RES_ALLOCA:.*]] = memref.alloca() {alignment = 64 : i64} : memref<32x32xf32> +// CHECK: %31:3 = scf.for %[[IV1:.*]] = %c0_i32 to %30 step %c1_i32 iter_args(%[[RES_IV:.*]] = %cst_0, %[[LHS_IV:.*]] = %[[LHS_TPTR]], %[[RHS_IV:.*]] = %[[RHS_TPTR]]) -> (vector<32x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { +// CHECK-NEXT: %[[LHS_MEMREF:.+]] = triton_cpu.extract_memref %[[LHS_IV]] : > -> memref> +// CHECK-NEXT: %[[LHS_INDICES:.+]]:4 = triton_cpu.extract_indices %[[LHS_IV]] : > -> index, index, index, index +// CHECK-NEXT: %[[LHS_VEC:.+]] = vector.transfer_read %[[LHS_MEMREF]][%[[LHS_INDICES]]#0, %[[LHS_INDICES]]#1, %[[LHS_INDICES]]#2, %[[LHS_INDICES]]#3], %cst {in_bounds = [true, true]} : memref>, vector<32x32xbf16> +// CHECK-NEXT: %[[RHS_MEMREF:.+]] = triton_cpu.extract_memref %[[RHS_IV]] : > -> memref> +// CHECK-NEXT: %[[RHS_INDICES:.+]]:4 = triton_cpu.extract_indices %[[RHS_IV]] : > -> index, index, index, index +// CHECK-NEXT: %[[RHS_VEC:.+]] = vector.transfer_read %[[RHS_MEMREF]][%[[RHS_INDICES]]#0, %[[RHS_INDICES]]#1, %[[RHS_INDICES]]#2, %[[RHS_INDICES]]#3], %cst {in_bounds = [true, true]} : memref>, vector<16x64xbf16> +// CHECK-NEXT: %[[LHS_VEC_T:.+]] = vector.transpose %[[LHS_VEC]], [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> +// CHECK-NEXT: %[[RHS_VEC_D_0:.+]], %[[RHS_VEC_D_1:.+]] = vector.deinterleave %[[RHS_VEC]] : vector<16x64xbf16> -> vector<16x32xbf16> +// CHECK-NEXT: %[[RHS_VEC_D_0_T:.+]] = vector.transpose %[[RHS_VEC_D_0]], [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> +// CHECK-NEXT: %[[RHS_VEC_D_1_T:.+]] = vector.transpose %[[RHS_VEC_D_1]], [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> +// CHECK-NEXT: %[[RHS_VEC_D_T:.+]] = vector.interleave %[[RHS_VEC_D_0_T]], %[[RHS_VEC_D_1_T]] : vector<32x16xbf16> -> vector<32x32xbf16> +// CHECK-NEXT: %[[RHS_VEC_D:.+]] = 
vector.transpose %[[RHS_VEC_D_T]], [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> +// CHECK: vector.transfer_write %[[LHS_VEC_T]], %[[LHS_ALLOCA]][%c0, %c0] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16> +// CHECK-NEXT: vector.transfer_write %[[RHS_VEC_D]], %[[RHS_ALLOCA]][%c0, %c0] {in_bounds = [true, true]} : vector<32x32xbf16>, memref<32x32xbf16> +// CHECK: %[[LHS_SUBVIEW:.+]] = memref.subview %[[LHS_ALLOCA]][0, 0] [32, 32] [1, 1] : memref<32x32xbf16> to memref<32x32xbf16, strided<[32, 1]>> +// CHECK-NEXT: %[[RHS_SUBVIEW:.+]] = memref.subview %[[RHS_ALLOCA]][0, 0] [32, 32] [1, 1] : memref<32x32xbf16> to memref<32x32xbf16, strided<[32, 1]>> +// CHECK-NEXT: %[[NONE1:.+]], %[[NONE2:.+]], %[[NONE3:.+]]:2, %[[LHS_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[LHS_SUBVIEW]] : memref<32x32xbf16, strided<[32, 1]>> -> memref, index, index, index, index, index +// CHECK-NEXT: %[[NONE4:.+]], %[[NONE5:.+]], %[[NONE6:.+]]:2, %[[RHS_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[RHS_SUBVIEW]] : memref<32x32xbf16, strided<[32, 1]>> -> memref, index, index, index, index, index +// CHECK-NEXT: %[[NONE7:.+]], %[[NONE8:.+]], %[[NONE0:.+]]:2, %[[ACC_STRIDES:.+]]:2 = memref.extract_strided_metadata %[[RES_ALLOCA]] : memref<32x32xf32> -> memref, index, index, index, index, index +// CHECK-NEXT: %[[ONEDNN_HANDLE:.+]] = "triton_cpu.brgemm_create"(%c32{{.*}}, %c32{{.*}}, %c32{{.*}}, %c1, %[[LHS_STRIDES]]#0, %[[RHS_STRIDES]]#0, %[[ACC_STRIDES]]#0) <{dtypeA = vector<32x32xbf16>, dtypeB = vector<32x32xbf16>, dtypeC = f32}> : (i64, i64, i64, index, index, index, index) -> index +// CHECK-NEXT: %[[BTW:.+]] = arith.constant 2 : i64 +// CHECK-NEXT: %[[BLOCK:.+]] = arith.muli %c32{{.*}}, %c32{{.*}} : i64 +// CHECK-NEXT: %[[BLOCKEDB_SIZE:.+]] = arith.muli %[[BLOCK]], %[[BTW]] : i64 +// CHECK-NEXT: "triton_cpu.brgemm_execute"(%[[ONEDNN_HANDLE]], %[[LHS_SUBVIEW]], %[[RHS_SUBVIEW]], %[[RES_ALLOCA]], %c0{{.*}}, %c0{{.*}}, %[[BLOCKEDB_SIZE]], %c1) : (index, memref<32x32xbf16, strided<[32, 1]>>, memref<32x32xbf16, strided<[32, 1]>>, memref<32x32xf32>, index, index, i64, index) -> () +// CHECK-NEXT: %[[LHS_IV_UPD:.*]] = tt.advance %[[LHS_IV]], [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > +// CHECK-NEXT: %[[RHS_IV_UPD:.*]] = tt.advance %[[RHS_IV]], [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > +// CHECK-NEXT: scf.yield %[[RES_IV]], %[[LHS_IV_UPD]], %[[RHS_IV_UPD]] : vector<32x32xf32>, !tt.ptr>, !tt.ptr> +// CHECK-NEXT: } +// CHECK: %[[RES:.+]] = vector.transfer_read %[[RES_ALLOCA]][%c0, %c0], %cst_3 {in_bounds = [true, true]} : memref<32x32xf32>, vector<32x32xf32> + + +#loc = loc(unknown) +module { + tt.func public @test_loop_with_transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg1: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg2: !tt.ptr {tt.divisibility = 16 : i32} loc(unknown), %arg3: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg4: i32 {tt.divisibility = 16 : i32} loc(unknown), %arg5: i32 {tt.divisibility = 16 : i32} loc(unknown)) attributes {noinline = false} { + %cst = arith.constant 0.000000e+00 : bf16 loc(#loc) + %c31_i32 = arith.constant 31 : i32 loc(#loc) + %c1024_i64 = arith.constant 1024 : i64 loc(#loc) + %cst_0 = arith.constant dense<0.000000e+00> : vector<32x32xf32> loc(#loc) + %c1_i64 = arith.constant 1 : i64 loc(#loc) + %c64_i64 = arith.constant 64 : i64 loc(#loc) + %c16_i64 = arith.constant 16 : i64 loc(#loc) + %c0_i32 = arith.constant 0 : i32 loc(#loc) + %c32_i64 = arith.constant 32 : i64 loc(#loc) + %c1_i32 = arith.constant 1 : i32 loc(#loc) + 
%c32_i32 = arith.constant 32 : i32 loc(#loc) + %c8_i32 = arith.constant 8 : i32 loc(#loc) + %0 = tt.get_program_id x : i32 loc(#loc) + %1 = arith.addi %arg3, %c31_i32 : i32 loc(#loc) + %2 = arith.divsi %1, %c32_i32 : i32 loc(#loc) + %3 = arith.addi %arg4, %c31_i32 : i32 loc(#loc) + %4 = arith.divsi %3, %c32_i32 : i32 loc(#loc) + %5 = arith.muli %4, %c8_i32 : i32 loc(#loc) + %6 = arith.divsi %0, %5 : i32 loc(#loc) + %7 = arith.muli %6, %c8_i32 : i32 loc(#loc) + %8 = arith.subi %2, %7 : i32 loc(#loc) + %9 = arith.minsi %8, %c8_i32 : i32 loc(#loc) + %10 = arith.remsi %0, %9 : i32 loc(#loc) + %11 = arith.addi %7, %10 : i32 loc(#loc) + %12 = arith.remsi %0, %5 : i32 loc(#loc) + %13 = arith.divsi %12, %9 : i32 loc(#loc) + %14 = arith.divsi %arg3, %c32_i32 : i32 loc(#loc) + %15 = arith.divsi %arg5, %c32_i32 : i32 loc(#loc) + %16 = arith.muli %arg5, %c32_i32 : i32 loc(#loc) + %17 = arith.extsi %14 : i32 to i64 loc(#loc) + %18 = arith.extsi %15 : i32 to i64 loc(#loc) + %19 = arith.extsi %16 : i32 to i64 loc(#loc) + %20 = tt.make_tensor_ptr %arg0, [%17, %18, %c32_i64, %c32_i64], [%19, %c1024_i64, %c32_i64, %c1_i64], [%11, %c0_i32, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %21 = arith.divsi %arg4, %c32_i32 : i32 loc(#loc) + %22 = arith.extsi %21 : i32 to i64 loc(#loc) + %23 = tt.make_tensor_ptr %arg1, [%18, %22, %c16_i64, %c64_i64], [%c1024_i64, %19, %c64_i64, %c1_i64], [%c0_i32, %13, %c0_i32, %c0_i32] {order = array} : > loc(#loc) + %24 = arith.muli %11, %c32_i32 : i32 loc(#loc) + %25 = arith.muli %13, %c32_i32 : i32 loc(#loc) + %26 = arith.extsi %arg3 : i32 to i64 loc(#loc) + %27 = arith.extsi %arg4 : i32 to i64 loc(#loc) + %28 = tt.make_tensor_ptr %arg2, [%26, %27], [%27, %c1_i64], [%24, %25] {order = array} : > loc(#loc) + %29 = arith.addi %arg5, %c31_i32 : i32 loc(#loc) + %30 = arith.divsi %29, %c32_i32 : i32 loc(#loc) + %31:3 = scf.for %arg6 = %c0_i32 to %30 step %c1_i32 iter_args(%arg7 = %cst_0, %arg8 = %20, %arg9 = %23) -> (vector<32x32xf32>, !tt.ptr>, !tt.ptr>) : i32 { + %34 = triton_cpu.extract_memref %arg8 : > -> memref> loc(#loc) + %35:4 = triton_cpu.extract_indices %arg8 : > -> index, index, index, index loc(#loc) + %36 = vector.transfer_read %34[%35#0, %35#1, %35#2, %35#3], %cst {in_bounds = [true, true]} : memref>, vector<32x32xbf16> loc(#loc) + %37 = triton_cpu.extract_memref %arg9 : > -> memref> loc(#loc) + %38:4 = triton_cpu.extract_indices %arg9 : > -> index, index, index, index loc(#loc) + %39 = vector.transfer_read %37[%38#0, %38#1, %38#2, %38#3], %cst {in_bounds = [true, true]} : memref>, vector<16x64xbf16> loc(#loc) + %40 = vector.transpose %36, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> loc(#loc) + %res1, %res2 = vector.deinterleave %39 : vector<16x64xbf16> -> vector<16x32xbf16> loc(#loc) + %41 = vector.transpose %res1, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %42 = vector.transpose %res2, [1, 0] : vector<16x32xbf16> to vector<32x16xbf16> loc(#loc) + %43 = vector.interleave %41, %42 : vector<32x16xbf16> -> vector<32x32xbf16> loc(#loc) + %44 = vector.transpose %43, [1, 0] : vector<32x32xbf16> to vector<32x32xbf16> loc(#loc) + %45 = triton_cpu.dot %40, %44, %arg7, inputPrecision = tf32 : vector<32x32xbf16> * vector<32x32xbf16> -> vector<32x32xf32> loc(#loc) + %46 = tt.advance %arg8, [%c0_i32, %c1_i32, %c0_i32, %c0_i32] : > loc(#loc) + %47 = tt.advance %arg9, [%c1_i32, %c0_i32, %c0_i32, %c0_i32] : > loc(#loc) + scf.yield %45, %46, %47 : vector<32x32xf32>, !tt.ptr>, !tt.ptr> loc(#loc) + } loc(#loc) + %32 = triton_cpu.extract_memref %28 : > -> 
memref> loc(#loc) + %33:2 = triton_cpu.extract_indices %28 : > -> index, index loc(#loc) + vector.transfer_write %31#0, %32[%33#0, %33#1] {in_bounds = [true, true]} : vector<32x32xf32>, memref> loc(#loc) + tt.return loc(#loc) + } loc(#loc) +} loc(#loc) diff --git a/third_party/cpu/CMakeLists.txt b/third_party/cpu/CMakeLists.txt index c62a4cda03ad..e24d25a65019 100644 --- a/third_party/cpu/CMakeLists.txt +++ b/third_party/cpu/CMakeLists.txt @@ -1,14 +1,30 @@ +find_package(dnnl CONFIG) +if (dnnl_FOUND) + message(STATUS "Found OneDNN/DNNL") + add_compile_definitions(ONEDNN_AVAILABLE) + get_target_property(dnnl_include DNNL::dnnl INTERFACE_INCLUDE_DIRECTORIES) + # currently used only in triton_cpu.cc and in ConvertDotToOneDNN + include_directories(${dnnl_include}) +else () + message(STATUS "Could NOT find OneDNN/DNNL") +endif() + include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) include_directories(${CMAKE_CURRENT_BINARY_DIR}/include) add_subdirectory(include) add_subdirectory(lib) if(TRITON_BUILD_PYTHON_MODULE) add_triton_plugin(TritonCPU ${CMAKE_CURRENT_SOURCE_DIR}/triton_cpu.cc LINK_LIBS TritonCPUToLLVM TritonCPUTransforms) - target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation PRIVATE Python3::Module pybind11::headers) + target_link_libraries(TritonCPU PUBLIC MLIRVectorToSCF MLIRAffineToStandard MLIRMathToLibm MLIRAMXToLLVMIRTranslation MLIRMemRefTransforms PRIVATE Python3::Module pybind11::headers) endif() -add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) -target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport) +if (dnnl_FOUND) + add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp ${CMAKE_CURRENT_SOURCE_DIR}/runtime/runtime_onednn.cpp) + target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport DNNL::dnnl) +else () + add_library(TritonCPURuntime SHARED ${CMAKE_CURRENT_SOURCE_DIR}/runtime/cpu_runtime.cpp) + target_link_libraries(TritonCPURuntime PRIVATE LLVMSupport) +endif() # Build and link sleef set(SLEEF_BUILD_SHARED_LIBS ON CACHE BOOL "Build sleef shared lib" FORCE) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index ad26f6d37157..6736976cc168 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -20,6 +20,7 @@ def min_dot_size(target: GPUTarget): VecLib = cpu.passes.ttcpuir.VecLib +Ukernels = cpu.passes.ttcpuir.Ukernels @dataclass(frozen=True) @@ -51,6 +52,7 @@ class CPUOptions: sanitize_overflow: bool = False # TODO: We may introduce CPU-specific options like # of cores. + ukernels: str = None def __post_init__(self): pass @@ -72,6 +74,25 @@ def get_vec_lib(self) -> VecLib: ) return vec_lib + def get_ukernels(self) -> Ukernels: + if self.ukernels is None: + return None + ukernels = Ukernels.__members__.get(self.ukernels, None) + if ukernels is None: + raise ValueError( + f"Unexpected value for ukernels: {self.ukernels}, should be one of {{{', '.join(Ukernels.__members__.keys())}}}" + ) + + if ukernels == Ukernels.OneDNN and not cpu.onednn_available(): + import warnings + # Warns on each compileation + warnings.simplefilter('once', category=UserWarning) + warnings.warn( + "Warning! Triton build was made without OneDNN support. Check if \"CMAKE_PREFIX_PATH\" contains path to OneDNN during build. 
\n\t -------OneDNN will NOT be used-------", + stacklevel=1) + return None + return ukernels + class CPUBackend(BaseBackend): @@ -163,6 +184,11 @@ def make_tttcir(self, mod, metadata, opt): cpu.passes.ttcpuir.add_triton_cpu_canonicalizer(pm) cpu.passes.ttcpuir.add_optimize_masks(pm) passes.common.add_canonicalizer(pm) + if (ukernels := opt.get_ukernels()): + # For further analysis simplification + cpu.passes.ttcpuir.add_loop_invariant_code_motion(pm) + cpu.passes.ttcpuir.add_convert_dot_to_ukernels(pm, ukernels) + passes.common.add_cse(pm) convert_bf16_dot_product = ((self.cpu_arch == "aarch64" or self.cpu_arch == "armv8") and 'fp-armv8' in self.cpu_features and 'neon' in self.cpu_features) if convert_bf16_dot_product: @@ -204,7 +230,10 @@ def make_llir(self, src, metadata, options): # TritonCPU -> LLVM-IR (MLIR) pm = ir.pass_manager(mod.context) pm.enable_debug() + if options.get_ukernels() == Ukernels.OneDNN: + cpu.passes.ttcpuir.add_ukernels_to_onednn_llvmir(pm) cpu.passes.ttcpuir.add_lower_vector_multi_dim(pm) + cpu.passes.ttcpuir.add_expand_strided_metadata(pm) cpu.passes.ttcpuir.add_vector_to_scf(pm, True, 1, False) cpu.passes.ttcpuir.add_lower_affine(pm) passes.convert.add_scf_to_cf(pm) diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.h b/third_party/cpu/include/TritonCPUToLLVM/Passes.h index cc29821c580c..bccd13d873f3 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.h +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.h @@ -31,6 +31,7 @@ std::unique_ptr> createGetProgramIdOpToLLVMPass(); std::unique_ptr> createLowerMultiReductionPass(); std::unique_ptr> createAtomicOpsToLLVMPass(); std::unique_ptr> createDebugOpsToLLVMPass(); +std::unique_ptr> createUkernelOpsToOneDNNLLVMPass(); std::unique_ptr> createMathToVecLibPass(VecLib lib = VecLib::Sleef, std::set cpu_features = {}); diff --git a/third_party/cpu/include/TritonCPUToLLVM/Passes.td b/third_party/cpu/include/TritonCPUToLLVM/Passes.td index 3ee08d9968b2..ddff5882a4cd 100644 --- a/third_party/cpu/include/TritonCPUToLLVM/Passes.td +++ b/third_party/cpu/include/TritonCPUToLLVM/Passes.td @@ -66,6 +66,17 @@ def DebugOpsToLLVM : Pass<"triton-cpu-debug-ops-to-llvm", "mlir::ModuleOp"> { "mlir::triton::TritonDialect"]; } +def UkernelOpsToOneDNNLLVM : Pass<"triton-cpu-ukernels-to-onednn-llvm", "mlir::ModuleOp"> { + let summary = "Convert ukernel operations to OneDNN LLVM runtime calls."; + let description = [{}]; + let constructor = "mlir::triton::cpu::createUkernelOpsToOneDNNLLVMPass()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::LLVM::LLVMDialect", + "mlir::triton::cpu::TritonCPUDialect", + "mlir::triton::TritonDialect"]; +} + def MathToVecLib : Pass<"triton-cpu-math-to-vec-lib", "mlir::ModuleOp"> { let summary = "Convert vector math operations to vector libm or sleef calls."; let description = [{ diff --git a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h index c3fe3973ce0b..b7f53d35071f 100644 --- a/third_party/cpu/include/TritonCPUTransforms/OptCommon.h +++ b/third_party/cpu/include/TritonCPUTransforms/OptCommon.h @@ -134,6 +134,8 @@ inline Value shapeCast(Location loc, Value in, #define op_subf(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_muli(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_mulf(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_divsi(lhs, rhs) rewriter.create(loc, lhs, rhs) +#define op_divui(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_bitcast(ty, val) rewriter.create(loc, ty, val) 
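// Illustrative use of the division helpers added above (a sketch only; it
// mirrors the batch-count computation in ConvertDotOpToUkernelOps.cpp):
//   Value numBatches = op_divui(op_subi(upperBound, lowerBound), step);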
#define op_lshr(lhs, rhs) rewriter.create(loc, lhs, rhs) #define op_shl(lhs, rhs) rewriter.create(loc, lhs, rhs) @@ -160,7 +162,7 @@ inline Value shapeCast(Location loc, Value in, #define op_extract(vec, idx) rewriter.create(loc, vec, idx) #define op_store(val, mem, idx) \ rewriter.create(loc, val, mem, idx) - +#define op_index_cast(ty, val) rewriter.create(loc, ty, val) #define op_icmp_eq(lhs, rhs) \ rewriter.create(loc, arith::CmpIPredicate::eq, lhs, rhs) #define op_icmp_ne(lhs, rhs) \ diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.h b/third_party/cpu/include/TritonCPUTransforms/Passes.h index f0c7a777e5fa..30ce9a4fb3e8 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.h +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.h @@ -15,6 +15,10 @@ template class OperationPass; namespace triton { namespace cpu { +enum class Ukernels { + OneDNN, +}; + #define GEN_PASS_DECL #include "cpu/include/TritonCPUTransforms/Passes.h.inc" @@ -40,6 +44,9 @@ std::unique_ptr> createConvertDotToFMA(); std::unique_ptr> createConvertDotGeneric(); std::unique_ptr> createCanonicalize(); +std::unique_ptr> createConvertDotOpToUkernelOps( + Ukernels ukernels = mlir::triton::cpu::Ukernels::OneDNN); + #define GEN_PASS_REGISTRATION #include "cpu/include/TritonCPUTransforms/Passes.h.inc" diff --git a/third_party/cpu/include/TritonCPUTransforms/Passes.td b/third_party/cpu/include/TritonCPUTransforms/Passes.td index 00c01a4725ce..5bed4074e817 100644 --- a/third_party/cpu/include/TritonCPUTransforms/Passes.td +++ b/third_party/cpu/include/TritonCPUTransforms/Passes.td @@ -130,6 +130,29 @@ def ConvertDotToFMA : Pass<"triton-cpu-convert-dot-to-fma", "mlir::ModuleOp"> { let description = [{ }]; let constructor = "mlir::triton::cpu::createConvertDotToFMA()"; + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + +def ConvertDotOpToUkernelOps : Pass<"triton-cpu-convert-dot-to-ukernels", "mlir::ModuleOp"> { + let summary = "Convert dot product op to ukernel ops."; + let description = [{ + This pass is used to lower DotOp operations to ukernel ops. 
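+
+    An illustrative sketch of the rewrite (operand order and attributes follow
+    the lit tests added in this patch; value names are placeholders): a
+    tile-sized triton_cpu.dot that accumulates inside an scf.for is replaced by
+    a kernel-handle creation plus an execute call on memref buffers, e.g.
+
+      %h = "triton_cpu.brgemm_create"(%m, %n, %k, %num_batches, %lda, %ldb, %ldc)
+             <{dtypeA = vector<64x64xbf16>, dtypeB = vector<64x32xbf16>, dtypeC = f32}>
+             : (i64, i64, i64, index, index, index, index) -> index
+      "triton_cpu.brgemm_execute"(%h, %lhs_buf, %rhs_buf, %acc_buf, %lhs_step,
+             %rhs_step, %blocked_b_size, %num_batches)
+             : (index, memref<64x64xbf16, strided<[128, 1], offset: ?>>,
+                memref<64x32xbf16, strided<[32, 1], offset: ?>>,
+                memref<64x32xf32>, index, index, i64, index) -> ()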
+ }]; + + let options = [ + Option<"ukernels", "ukernels", + "mlir::triton::cpu::Ukernels", /*default*/"mlir::triton::cpu::Ukernels::OneDNN", + "Ukernels provider to be used for replacement of dot op (OneDNN/Xsmm/etc).", + [{::llvm::cl::values( + clEnumValN(mlir::triton::cpu::Ukernels::OneDNN, "oneDNN", + "Use OneDNN as a ukernels provider") + )}]>, + ]; + + let constructor = "mlir::triton::cpu::createConvertDotOpToUkernelOps()"; let dependentDialects = ["mlir::arith::ArithDialect", "mlir::vector::VectorDialect", diff --git a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt index 5448d81937f4..c0251a42c446 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUToLLVM/CMakeLists.txt @@ -1,6 +1,7 @@ add_triton_library(TritonCPUToLLVM AtomicOpsToLLVM.cpp DebugOpsToLLVM.cpp + UkernelOpsToOneDNNLLVM.cpp FuncOpToLLVM.cpp GetProgramIdOpToLLVM.cpp LowerMultiReduction.cpp diff --git a/third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp new file mode 100644 index 000000000000..dc56651e968b --- /dev/null +++ b/third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp @@ -0,0 +1,210 @@ +#include "TypeConverter.h" +#include "Utility.h" + +#include "cpu/include/TritonCPUToLLVM/Passes.h" + +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUOps.h.inc" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#if defined(ONEDNN_AVAILABLE) +#include "oneapi/dnnl/dnnl_types.h" +#endif + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_UKERNELOPSTOONEDNNLLVM +#include "cpu/include/TritonCPUToLLVM/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +#if defined(ONEDNN_AVAILABLE) +#include "oneapi/dnnl/dnnl_config.h" +#endif +void assert_on_onednn_missing() { +#if !defined(DNNL_EXPERIMENTAL_UKERNEL) + assert(false && "No OneDNN with uKernels available. 
Pass will be redundant."); +#endif +} + +inline Value intLLVMConst(Location loc, Type ty, int64_t val, + PatternRewriter &rewriter) { + return rewriter.create( + loc, IntegerAttr::get(getElementTypeOrSelf(ty), val)); +} + +static inline int64_t getDnnlDataTypeVal(Type ty) { +#if defined(DNNL_EXPERIMENTAL_UKERNEL) + ty = getElementTypeOrSelf(ty); + if (ty.isF32()) + return static_cast(dnnl_f32); + if (ty.isF64()) + return static_cast(dnnl_f64); + if (ty.isBF16()) + return static_cast(dnnl_bf16); + if (ty.isF16()) + return static_cast(dnnl_f16); +#endif + assert_on_onednn_missing(); + llvm_unreachable("Unexpected type for conversion to DNNL type."); +} + +class TritonLLVMConversionTarget : public ConversionTarget { +public: + explicit TritonLLVMConversionTarget(MLIRContext &ctx) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalOp(); + } +}; + +LLVM::LLVMFuncOp getFuncDecl(ConversionPatternRewriter &rewriter, + StringRef funcName, SmallVector argsType, + Type resultType) { + auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); + Operation *funcOp = moduleOp.lookupSymbol(funcName); + if (funcOp) + return cast(*funcOp); + + auto *ctx = rewriter.getContext(); + + auto funcType = + LLVM::LLVMFunctionType::get(resultType, argsType, /*isVarArg*/ false); + + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + return rewriter.create(UnknownLoc::get(ctx), funcName, + funcType); +} + +struct BrgemmCreateConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(BrgemmCreate brgemmOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = brgemmOp.getLoc(); + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + + std::string dispatchName = "create_brgemm"; + + auto lhsDnnType = i64_val(getDnnlDataTypeVal(adaptor.getDtypeA())); + auto rhsDnnType = i64_val(getDnnlDataTypeVal(adaptor.getDtypeB())); + auto accDnnType = i64_val(getDnnlDataTypeVal(adaptor.getDtypeC())); + + auto brgemmArgs = + SmallVector{adaptor.getM(), adaptor.getN(), + adaptor.getKK(), adaptor.getBatchSize(), + adaptor.getLda(), adaptor.getLdb(), + adaptor.getLdc(), lhsDnnType, + rhsDnnType, accDnnType}; + SmallVector brgemmArgTypes{i64_ty, i64_ty, i64_ty, i64_ty, i64_ty, + i64_ty, i64_ty, i64_ty, i64_ty, i64_ty}; + + auto dispatched = LLVM::createLLVMCallOp( + rewriter, loc, + getFuncDecl( + rewriter, dispatchName, brgemmArgTypes, + getTypeConverter()->convertType(brgemmOp.getResult().getType())), + brgemmArgs); + + rewriter.replaceOp(brgemmOp, dispatched.getResult()); + return success(); + }; +}; + +struct BrgemmExecuteConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(BrgemmExecute brgemmOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto loc = brgemmOp.getLoc(); + auto ctx = rewriter.getContext(); + + std::string invokeName = "brgemm_execute"; + + auto brgemm_kernel_hash_ptr = rewriter.create( + loc, ptr_ty(ctx), adaptor.getBrgemmKernelHash()); + + auto brgemmArgs = SmallVector{ + // tf_kernel_hash_ptr, + brgemm_kernel_hash_ptr, + MemRefDescriptor(adaptor.getAPtr()) + .bufferPtr(rewriter, loc, *getTypeConverter(), + cast(brgemmOp.getAPtr().getType())), + MemRefDescriptor(adaptor.getBPtr()) + .bufferPtr(rewriter, loc, *getTypeConverter(), + cast(brgemmOp.getBPtr().getType())), + 
MemRefDescriptor(adaptor.getCPtr()) + .bufferPtr(rewriter, loc, *getTypeConverter(), + cast(brgemmOp.getCPtr().getType())), + adaptor.getStepA(), + adaptor.getStepB(), + adaptor.getBlockedBsize(), + adaptor.getNumBatches()}; + + auto brgemmArgTypes = + SmallVector{ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), ptr_ty(ctx), + i64_ty, i64_ty, i64_ty, i64_ty}; + + auto dispatched = LLVM::createLLVMCallOp( + rewriter, loc, + getFuncDecl(rewriter, invokeName, brgemmArgTypes, void_ty(ctx)), + brgemmArgs); + + rewriter.replaceOp(brgemmOp, dispatched); + return success(); + }; +}; + +struct UkernelOpsToOneDNNLLVM + : public triton::cpu::impl::UkernelOpsToOneDNNLLVMBase< + UkernelOpsToOneDNNLLVM> { + UkernelOpsToOneDNNLLVM() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + mlir::LowerToLLVMOptions option(context); + TritonCPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMConversionTarget conversionTarget(*context); + + RewritePatternSet patterns(context); + + patterns.add( + typeConverter); + + if (failed(applyPartialConversion(mod, conversionTarget, + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // anonymous namespace + +namespace mlir::triton::cpu { + +std::unique_ptr> createUkernelOpsToOneDNNLLVMPass() { + return std::make_unique(); +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt index c6e9b4ed69e6..ff55dbf54dff 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt +++ b/third_party/cpu/lib/TritonCPUTransforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonCPUTransforms ConvertDotOp/ConvertDotGeneric.cpp ConvertDotOp/ConvertDotToAMX.cpp ConvertDotOp/ConvertDotToFMA.cpp + ConvertDotOp/ConvertDotOpToUkernelOps.cpp Canonicalize.cpp ConvertDotProduct.cpp ConvertUnsupportedOps.cpp diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp index 8fc432b9734e..9ce0c883ff10 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.cpp @@ -29,8 +29,6 @@ bool isLoopCarriedAcc(Value acc) { return false; } - blockArg.getArgNumber(); - Value updAcc = acc.getUsers().begin()->getResult(0); if (!updAcc.hasOneUse()) { LDBG(" No. 
Has multiple uses."); @@ -199,6 +197,7 @@ MemBuffer findInputBuffer(Value val, bool allowTransposed, bool allowVnni) { LLVM_DEBUG(DBGS() << " Step: "; llvm::interleaveComma(buf.step, llvm::dbgs()); llvm::dbgs() << "\n"); + buf.origBlockPtr = forOp.getTiedLoopInit(blockPtrArg)->get(); return buf; } @@ -221,8 +220,9 @@ Value maybeCast(Location loc, Value val, Type dstElemTy, return rewriter.create(loc, dstTy, val); } -MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, - Operation *allocaPoint, PatternRewriter &rewriter) { +MemBuffer allocateTmpBufferStack(Location loc, VectorType vecTy, + Operation *allocaPoint, + PatternRewriter &rewriter) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(allocaPoint); auto memRefTy = MemRefType::get(vecTy.getShape(), vecTy.getElementType()); @@ -250,6 +250,15 @@ Value shiftIndex(Location loc, Value index, int64_t offs, return rewriter.create(loc, index.getType(), index, offsVal); } +MemBuffer storeToTmpBuffer(Location loc, Value val, Operation *allocaPoint, + PatternRewriter &rewriter) { + LDBG("Storing vector to a temporary buffer: " << val); + auto vecTy = cast(val.getType()); + MemBuffer buf = allocateTmpBufferStack(loc, vecTy, allocaPoint, rewriter); + op_write(val, buf.memRef, buf.indices); + return buf; +} + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h index 2760ebd14fbb..f824994ee556 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotCommon.h @@ -22,6 +22,9 @@ struct MemBuffer { // on each iteration, then step can hold those index offsets. // Empty step doesn't mean indices are loop invariant. SmallVector step; + // When step is known, this field holds the initial block + // pointer value used in the first iteration. + Value origBlockPtr = nullptr; // True if buffer holds transposed value. bool transposed = false; // Ttue if buffer holds value in VNNI (interleaved to groups of 32bit) @@ -70,8 +73,9 @@ Value maybeCast(Location loc, Value val, Type dstElemTy, PatternRewriter &rewriter); // Allocate temporary buffer on stack for specified vector type. -MemBuffer allocateTmpBuffer(Location loc, VectorType vecTy, - Operation *allocaPoint, PatternRewriter &rewriter); +MemBuffer allocateTmpBufferStack(Location loc, VectorType vecTy, + Operation *allocaPoint, + PatternRewriter &rewriter); // Move index by specified offset. Do constannt folding if possible. Value shiftIndex(Location loc, Value index, int64_t offs, @@ -81,6 +85,9 @@ Value shiftIndex(Location loc, Value index, int64_t offs, // If it is, then return the original encoded value. Otherwise, return nullptr. 
Value getVnniSrc(Value val); +MemBuffer storeToTmpBuffer(Location loc, Value val, Operation *allocaPoint, + PatternRewriter &rewriter); + } // namespace cpu } // namespace triton } // namespace mlir diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotOpToUkernelOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotOpToUkernelOps.cpp new file mode 100644 index 000000000000..52d6f32a1ba6 --- /dev/null +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotOpToUkernelOps.cpp @@ -0,0 +1,480 @@ +#include "ConvertDotCommon.h" + +#include "cpu/include/TritonCPUTransforms/Passes.h" + +#include "cpu/include/Analysis/TensorPtrShapeInfo.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include +#include +#include + +namespace mlir { +namespace triton { +namespace cpu { +#define GEN_PASS_DEF_CONVERTDOTOPTOUKERNELOPS +#include "cpu/include/TritonCPUTransforms/Passes.h.inc" +} // namespace cpu +} // namespace triton +} // namespace mlir + +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +// This structure is used to hold candidates for conversion to ukernel calls. +struct DotOpCandidate { + // Operation to convert. + triton::cpu::DotOp op; + + // Block sizes. + int64_t blockM; + int64_t blockN; + int64_t blockK; + // If accumulator is updated in a loop, then this flag indicates if we + // should keep it in tiles the whole loop and move back to vectors only + // after the loop. + bool isAccLoopCarried = false; + bool canFuseLoop = false; + + // If input data is available in memory then input buffers hold it. + MemBuffer lhsBuf; + MemBuffer rhsBuf; +}; + +bool isLoopInvariant(SmallVector vals, LoopLikeOpInterface loopLike) { + for (Value val : vals) { + LDBG("Checking value for invariance: " << val); + if (!loopLike.isDefinedOutsideOfLoop(val)) { + LDBG(" Not invariant"); + return false; + } + } + return true; +} + +bool checkElemTypes(Type lhsElemTy, Type rhsElemTy, Type accElemTy, + Type resElemTy) { + // Integer types are not supported yet. + if (lhsElemTy.isInteger() || rhsElemTy.isInteger() || resElemTy.isInteger()) { + // Should be also lhs = [u8, s8] rhs = [u8, s8] res = [s32] + // but there is an assertion if res not f32 TODO - verify. + LDBG("Drop candidate. Integer types are not supported."); + return false; + } + + // FP8 input is not supported yet. + if (lhsElemTy.getIntOrFloatBitWidth() == 8 || + rhsElemTy.getIntOrFloatBitWidth() == 8) { + LDBG("Drop candidate. FP8 input is not supported."); + return false; + } + + // FP64 result is not supported. + if (accElemTy.getIntOrFloatBitWidth() == 64 || + resElemTy.getIntOrFloatBitWidth() == 64) { + LDBG("Drop candidate. FP64 result is not supported."); + return false; + } + + return true; +} + +bool checkInputShapes(VectorType lhsTy, VectorType resTy, + DotOpCandidate &candidate) { + if (lhsTy.getRank() != 2) + return false; + + candidate.blockM = resTy.getDimSize(0); + candidate.blockN = resTy.getDimSize(1); + candidate.blockK = lhsTy.getDimSize(1); + + // Todo enable types that require transform (bfloat16, fp16, int8) to have + // block-size (blockN) more than 64 (OneDNN ukernels transform issue) + if (candidate.blockN > 64 && lhsTy.getElementTypeBitWidth() < 32) { + LDBG("Drop candidate. BlockN > 64 && type requires transform. 
(btw < 32)"); + return false; + } + + return true; +} + +// Check if specified ContractionOp can be lowered to OneDNN ukernel operations. +// If conversion is possible, then true is returned and candidate +// structure is filled with detailed transformation info. +bool isUkernelsCandidate(triton::cpu::DotOp op, DotOpCandidate &candidate) { + VectorType lhsTy = cast(op.getA().getType()); + VectorType rhsTy = cast(op.getB().getType()); + VectorType accTy = cast(op.getC().getType()); + VectorType resTy = cast(op.getType()); + + LDBG("Considering candidate op: " << op); + + if (accTy.getRank() != 2) { + LDBG(" Drop candidate. Only 2D case is supported."); + return false; + } + + // Check input/output types. + if (!checkElemTypes(lhsTy.getElementType(), rhsTy.getElementType(), + accTy.getElementType(), resTy.getElementType())) + return false; + + // Check input shapes. + if (!checkInputShapes(lhsTy, resTy, candidate)) + return false; + + candidate.op = op; + candidate.isAccLoopCarried = isLoopCarriedAcc(op.getC()); + candidate.lhsBuf = findInputBuffer(op.getA(), false); + candidate.rhsBuf = findInputBuffer(op.getB(), false); + + // Check if we can fuse dot op loop into a single brgemm call. + if (candidate.isAccLoopCarried && !candidate.lhsBuf.step.empty() && + !candidate.rhsBuf.step.empty()) { + SmallVector valsToCheckInvariance; + valsToCheckInvariance.append(candidate.lhsBuf.step); + valsToCheckInvariance.append(candidate.rhsBuf.step); + + auto forOp = dyn_cast(op->getParentOp()); + candidate.canFuseLoop = isLoopInvariant(valsToCheckInvariance, forOp); + } + return true; +} + +Value addMemrefSubView(PatternRewriter &rewriter, Location loc, Value vecVal, + ValueRange indices, Value memRef) { + LDBG(" Reusing the original memref for a buffer: " << memRef); + auto vecTy = cast(vecVal.getType()); + auto ctx = rewriter.getContext(); + auto memrefTy = cast(memRef.getType()); + SmallVector strides(memrefTy.getRank(), 1); + SmallVector shape(memrefTy.getRank(), 1); + // we will add 1 to leading dimensions of shapes or just copy existing vector + // shape. 
+ int64_t start_ind = memrefTy.getRank() - vecTy.getRank(); + for (auto ind = 0; ind < vecTy.getRank(); ind++, start_ind++) { + shape[start_ind] = vecTy.getShape()[ind]; + } + + Value memRef_view = rewriter.create( + loc, memRef, getAsOpFoldResult(indices), + getAsIndexOpFoldResult(ctx, shape), getAsIndexOpFoldResult(ctx, strides)); + return memRef_view; +} + +std::pair> +extractBufferFromBlockPtr(Value blockPtr, triton::cpu::DotOp &dotOp, + ModuleTensorPtrShapeInfoAnalysis &shapeAnalysis, + PatternRewriter &rewriter) { + Location loc = dotOp.getLoc(); + MLIRContext *ctx = dotOp.getContext(); + + auto extractMemref = [&](Value ptr) { + auto tensorTy = dyn_cast( + dyn_cast(ptr.getType()).getPointeeType()); + auto elemTy = tensorTy.getElementType(); + auto shapeInfo = shapeAnalysis.getPtrShapeInfo(ptr); + Type memRefTy; + if (shapeInfo && shapeInfo->getRank() > 0) { + auto layout = StridedLayoutAttr::get(ctx, 0, shapeInfo->getStrides()); + memRefTy = MemRefType::get(shapeInfo->getShape(), elemTy, layout); + } else { + SmallVector dynVals(tensorTy.getRank(), ShapedType::kDynamic); + auto layout = StridedLayoutAttr::get(ctx, 0, dynVals); + memRefTy = MemRefType::get(dynVals, elemTy, layout); + } + return rewriter.create(loc, memRefTy, ptr); + }; + + auto memRef = extractMemref(blockPtr); + auto indices = rewriter.create(loc, blockPtr).getResults(); + + return {memRef, indices}; +} + +Value computeStepInBytes(Location loc, memref::ExtractStridedMetadataOp meta, + ArrayRef steps, PatternRewriter &rewriter) { + Value res = index_cst(0); + if (steps.empty()) + return res; + + SmallVector strides = meta.getStrides(); + for (uint i = 0; i < strides.size(); i++) { + LDBG("[compute step]: " << i << "\n\tstride: " << strides[i] + << "\n\tstep: " << steps[i]); + Value stride = strides[i]; + Value step = steps[i]; + if (!step.getType().isIndex()) + step = op_index_cast(rewriter.getIndexType(), step); + Value mul = op_muli(step, stride); + res = op_addi(res, mul); + LDBG("[compute step]: mul " << mul); + } + + Value dtSize = index_cst( + getElementTypeOrSelf(meta.getBaseBuffer()).getIntOrFloatBitWidth() / 8); + res = op_muli(res, dtSize); + return res; +} + +// 1. DotOp is replaced with BRGEMM call (aka loop collapse) +// CONDITIONS: +// - Acc is loop-carried +// - Input buffers should exists, have steps and basic block pointers +// - Buffer steps and block pointers should be loop invariants +// - All generation goes out of the loop +// - The original dot op result uses are replaced with its acc operand (to +// make it dead code) + +// 2. 
DotOp is replaced with GEMM call +// a) Acc is loop-carried +// - Create buf for acc before the loop +// -- OPT: use output buffer instead of the temporary one +// - Put init acc values into the buf before the loop +// - Load acc from buf after the loop and replace loop result uses with loaded +// acc +// - The original dot op result uses are replaced with its acc operand (to +// make it dead code) +// -- OPT: Remove the original store if output buffer was used for acc + +// b) Acc is not loop-carried +// - Create buf for acc before the loop +// - Put acc value into the buf before the dot op +// - Load acc from buf after GEMM call and replace orig dot op result uses +// with loaded acc +// - The original dot op is removed + +LogicalResult +convertCandidate(DotOpCandidate &candidate, + ModuleTensorPtrShapeInfoAnalysis &shapeInfoAnalysis, + PatternRewriter &rewriter) { + triton::cpu::DotOp op = candidate.op; + Location loc = op.getLoc(); + VectorType resTy = cast(op.getResult().getType()); + Type resElemTy = resTy.getElementType(); + + scf::ForOp forOp = dyn_cast(op->getParentOp()); + Value numBatches = index_cst(1); + if (candidate.isAccLoopCarried && candidate.canFuseLoop) { + // We can fully replace the loop with one op. + + // Initial tile values are loaded before the loop and then directly + // used within the loop. Later, new iter values will be added to + // add loop carried-dependencies for accumulator tiles and accInitTiles + // will be used as initializers for them. + rewriter.setInsertionPoint(forOp); + auto memrefsFromBlockPtr = + extractBufferFromBlockPtr(candidate.lhsBuf.origBlockPtr, candidate.op, + shapeInfoAnalysis, rewriter); + candidate.lhsBuf.memRef = memrefsFromBlockPtr.first; + candidate.lhsBuf.indices = memrefsFromBlockPtr.second; + + memrefsFromBlockPtr = + extractBufferFromBlockPtr(candidate.rhsBuf.origBlockPtr, candidate.op, + shapeInfoAnalysis, rewriter); + candidate.rhsBuf.memRef = memrefsFromBlockPtr.first; + candidate.rhsBuf.indices = memrefsFromBlockPtr.second; + + LDBG("Loading accumulator to tiles before the loop."); + + numBatches = op_divui(op_subi(forOp.getUpperBound(), forOp.getLowerBound()), + forOp.getStep()); + numBatches = op_index_cast(rewriter.getIndexType(), numBatches); + } + + Operation *allocaPoint = op; + while (!isa(allocaPoint->getParentOp())) + allocaPoint = allocaPoint->getParentOp(); + + IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); + auto blockM = int_cst(integer64, candidate.blockM); + auto blockN = int_cst(integer64, candidate.blockN); + auto blockK = int_cst(integer64, candidate.blockK); + + if (candidate.lhsBuf.empty()) { + candidate.lhsBuf = + storeToTmpBuffer(loc, candidate.op.getA(), allocaPoint, rewriter); + } + + if (candidate.rhsBuf.empty()) { + candidate.rhsBuf = + storeToTmpBuffer(loc, candidate.op.getB(), allocaPoint, rewriter); + } + + Value accToStore = op.getC(); + if (candidate.isAccLoopCarried) { + LDBG("Setting insertion op to forOp. (accToStore)"); + forOp = cast(op->getParentOp()); + accToStore = getInitAccValue(accToStore); + } + + MemBuffer accBuf; + { + // If accumulator is bufferized then we should move initial values before + // the loop. + OpBuilder::InsertionGuard g(rewriter); + if (candidate.isAccLoopCarried) { + LDBG("String Setting insertion op to forOp. (accBuf)"); + rewriter.setInsertionPoint(forOp); + } + // Currently, acc always needs to be FP32. 
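+    // (The brgemm handle below is created with dtypeC = f32 as well, so bf16
+    // and f16 accumulators are widened here and narrowed back on the result
+    // paths via maybeCast.)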
+ accToStore = maybeCast(loc, accToStore, rewriter.getF32Type(), rewriter); + accBuf = storeToTmpBuffer(loc, accToStore, allocaPoint, rewriter); + } + + auto lhsSubView = + addMemrefSubView(rewriter, loc, candidate.op.getA(), + candidate.lhsBuf.indices, candidate.lhsBuf.memRef); + auto rhsSubView = + addMemrefSubView(rewriter, loc, candidate.op.getB(), + candidate.rhsBuf.indices, candidate.rhsBuf.memRef); + + auto metadataA = + rewriter.create(loc, lhsSubView); + auto metadataB = + rewriter.create(loc, rhsSubView); + auto metadataAcc = + rewriter.create(loc, accBuf.memRef); + + Value lda = metadataA.getStrides()[metadataA.getStrides().size() - 2]; + Value ldb = metadataB.getStrides()[metadataB.getStrides().size() - 2]; + Value ldc = metadataAcc.getStrides()[metadataAcc.getStrides().size() - 2]; + + Value brgemm = rewriter.create( + loc, rewriter.getIndexType(), blockM, blockN, blockK, numBatches, lda, + ldb, ldc, op.getA().getType(), op.getB().getType(), + rewriter.getF32Type()); + + auto rhsTypeSize = + int_cst(integer64, op.getB().getType().getElementTypeBitWidth() / 8); + Value rhsBlockSizeInBytes = op_muli(op_muli(blockN, blockK), rhsTypeSize); + + LDBG("[prepareResultBuffer] prepared acc buf: " << accBuf.memRef); + LDBG("lhsBuf: { memref " + << candidate.lhsBuf.memRef << "\n " + << " indices " << candidate.lhsBuf.indices.size() << "\n" + << " step " << candidate.lhsBuf.step.size() << "\n" + << " blockptr " << candidate.lhsBuf.origBlockPtr << "\n" + << " transposed " << candidate.lhsBuf.transposed << "\n} \n"); + LDBG("rhsBuf: { memref " + << candidate.rhsBuf.memRef << "\n " + << " indices " << candidate.rhsBuf.indices.size() << "\n" + << " step " << candidate.rhsBuf.step.size() << "\n" + << " blockptr " << candidate.rhsBuf.origBlockPtr << "\n" + << " transposed " << candidate.rhsBuf.transposed << "\n} \n"); + + Value lhsStepInBytes = + computeStepInBytes(loc, metadataA, candidate.lhsBuf.step, rewriter); + Value rhsStepInBytes = + computeStepInBytes(loc, metadataB, candidate.rhsBuf.step, rewriter); + + rewriter.create( + loc, brgemm, lhsSubView, rhsSubView, accBuf.memRef, lhsStepInBytes, + rhsStepInBytes, rhsBlockSizeInBytes, numBatches); + + if (candidate.isAccLoopCarried && candidate.canFuseLoop) { + LDBG("Loading the result to a vector to replace orig op result."); + Value newVal = + op_read(cast(toFp32(resTy)), accBuf.memRef, accBuf.indices); + + // Hope that dead code elemination do the rest. + rewriter.replaceOp(candidate.op, candidate.op.getC()); + + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, resElemTy, rewriter); + rewriter.replaceAllOpUsesWith( + forOp, + ValueRange{newVal, candidate.lhsBuf.memRef, candidate.rhsBuf.memRef}); + return success(); + } + + if (candidate.isAccLoopCarried) { + rewriter.setInsertionPointAfter(forOp); + auto rank = dyn_cast(accBuf.memRef.getType()).getRank(); + SmallVector inBounds(rank, false); + Value newVal = + op_read(cast(toFp32(resTy)), accBuf.memRef, accBuf.indices); + // We might need to cast back to the original type. + newVal = maybeCast(loc, newVal, resElemTy, rewriter); + int resIdx = op.getResult().getUses().begin()->getOperandNumber(); + Value loopRes = forOp.getResult(resIdx); + loopRes.replaceAllUsesWith(newVal); + rewriter.replaceOp(op, op.getC()); + return success(); + } + LDBG("Loading the result to a vector to replace orig op result."); + Value newVal = rewriter.create( + loc, cast(toFp32(resTy)), accBuf.memRef, accBuf.indices); + // We might need to cast back to the original type. 
+ newVal = maybeCast(loc, newVal, resElemTy, rewriter); + op.getResult().replaceAllUsesWith(newVal); + rewriter.eraseOp(op); + return success(); +} + +struct ConvertDotOpToUkernelOps + : public triton::cpu::impl::ConvertDotOpToUkernelOpsBase< + ConvertDotOpToUkernelOps> { + ConvertDotOpToUkernelOps() = default; + ConvertDotOpToUkernelOps(Ukernels ukernels) { this->ukernels = ukernels; } + + void runOnOperation() override { + if (ukernels != Ukernels::OneDNN) { + LDBG("Pass disabled."); + return; + } + + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + ModuleTensorPtrShapeInfoAnalysis shapeInfoAnalysis(mod); + + SmallVector candidates; + mod->walk([&candidates](triton::cpu::DotOp op) { + DotOpCandidate candidate; + if (isUkernelsCandidate(op, candidate)) { + LLVM_DEBUG({ + LDBG("Found OneDNN candidate"); + LDBG(" Op: " << candidate.op); + LDBG(" blockM: " << candidate.blockM); + LDBG(" blockN: " << candidate.blockN); + LDBG(" blockK: " << candidate.blockK); + LDBG(" isAccLoopCarried: " << candidate.isAccLoopCarried); + LDBG(" canFuseLoop: " << candidate.canFuseLoop); + }); + candidates.push_back(candidate); + } + return WalkResult::advance(); + }); + + for (auto &candidate : candidates) { + LDBG("Starting conversion of candidate: " << candidate.op); + PatternRewriter rewriter(context); + rewriter.setInsertionPoint(candidate.op); + if (succeeded(convertCandidate(candidate, shapeInfoAnalysis, rewriter))) { + LDBG("Conversion succeeded!"); + } else { + LDBG("Conversion failed!"); + } + } + } +}; + +} // namespace + +namespace mlir::triton::cpu { + +std::unique_ptr> +createConvertDotOpToUkernelOps(Ukernels ukernels) { + return std::make_unique(ukernels); +} + +} // namespace mlir::triton::cpu diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp index 11ce852e7570..1e955a5fbdd5 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToAMX.cpp @@ -213,18 +213,6 @@ void setupBlockAndTileSizes(ArrayRef lhsShape, candidate.tilesInBlockN = accBlocksN; } -// Check if vector transfer read/write operation uses a mask -// or involves a bounds check. -template bool hasMaskOrBoundsCheck(T op) { - auto inBounds = op.getInBounds(); - Value mask = op.getMask(); - bool hasBoundsCheck = - std::any_of(inBounds.begin(), inBounds.end(), [](Attribute attr) { - return !cast(attr).getValue(); - }); - return hasBoundsCheck || mask; -} - // Check if a value is used only for a store and that this store can be // replaced with tile stores. In this case fill appropriate fields in the // candidate structure. 
@@ -407,8 +395,8 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, if (interleave && !inputBuf.vnni) { LDBG(" Copying from the original memref with interleave: " << inputBuf.memRef); - auto tmpBuf = allocateTmpBuffer(loc, getSwizzledRhsTileType(vecTy), - allocaPoint, rewriter); + auto tmpBuf = allocateTmpBufferStack(loc, getSwizzledRhsTileType(vecTy), + allocaPoint, rewriter); copyWithInterleave(loc, vecTy, inputBuf, tmpBuf, rewriter); return tmpBuf; } @@ -423,7 +411,7 @@ MemBuffer prepareTensorBuffer(Location loc, Value val, bool interleave, if (interleave) vecTy = getSwizzledRhsTileType(vecTy); - MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); + MemBuffer buf = allocateTmpBufferStack(loc, vecTy, allocaPoint, rewriter); if (interleave) { auto interleavedVal = getVnniSrc(val); @@ -456,8 +444,8 @@ MemBuffer prepareResultBuffer(Location loc, Value val, const MemBuffer &accBuf, } LDBG("Allocating buffer for the result."); - return allocateTmpBuffer(loc, cast(val.getType()), allocaPoint, - rewriter); + return allocateTmpBufferStack(loc, cast(val.getType()), + allocaPoint, rewriter); } SmallVector shiftIndices(Location loc, ArrayRef indices, diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp index 4d1832ca8cf9..76ec529729ed 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotOp/ConvertDotToFMA.cpp @@ -124,15 +124,6 @@ bool isFmaCandidate(cpu::DotOp op, FmaDotOpCandidate &candidate) { return true; } -MemBuffer storeToTmpBuffer(Location loc, Value val, Operation *allocaPoint, - PatternRewriter &rewriter) { - LDBG("Storing vector to a temporary buffer: " << val); - auto vecTy = cast(val.getType()); - MemBuffer buf = allocateTmpBuffer(loc, vecTy, allocaPoint, rewriter); - rewriter.create(loc, val, buf.memRef, buf.indices); - return buf; -} - SmallVector shiftIndices(Location loc, ArrayRef indices, bool transposed, int64_t m, int64_t n, PatternRewriter &rewriter) { diff --git a/third_party/cpu/runtime/runtime_onednn.cpp b/third_party/cpu/runtime/runtime_onednn.cpp new file mode 100644 index 000000000000..1c3ce11ea939 --- /dev/null +++ b/third_party/cpu/runtime/runtime_onednn.cpp @@ -0,0 +1,155 @@ +#if defined(ONEDNN_AVAILABLE) +#include "oneapi/dnnl/dnnl_types.h" +#include "oneapi/dnnl/dnnl_ukernel.hpp" +#include "oneapi/dnnl/dnnl_ukernel_types.h" +#if !defined(DNNL_EXPERIMENTAL_UKERNEL) +#error "DNNL Ukerenel ismissing" +#endif +#endif + +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#define EXPORT __declspec(dllexport) +#elif defined(__GNUC__) +#define EXPORT __attribute__((visibility("default"))) +#else +#define EXPORT +#endif + +#if defined(ONEDNN_AVAILABLE) +using namespace dnnl; +using namespace dnnl::ukernel; +#endif + +using read_lock_guard_t = std::shared_lock; +using write_lock_guard_t = std::unique_lock; +static std::shared_mutex g_brgemm_lock; + +extern "C" { + +struct onednn_handle { + dnnl::ukernel::transform transform; + dnnl::ukernel::brgemm brg; +}; + +EXPORT void *create_brgemm(int64_t M, int64_t N, int64_t K_k, + int64_t batch_size, int64_t lda, int64_t ldb, + int64_t ldc, int64_t dtypeA, int64_t dtypeB, + int64_t dtypeC) { + using KeyT = std::array; + KeyT key{M, N, K_k, batch_size, lda, ldb, ldc, dtypeA, dtypeB, dtypeC}; + + static std::map savedUkernels; + { + read_lock_guard_t r_g(g_brgemm_lock); 
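+    // Fast path: a shared lock is enough to look up an already-generated
+    // kernel; on a miss we fall through, retake the mutex exclusively and
+    // re-check before generating (double-checked caching).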
+ if (savedUkernels.count(key) != 0) { + return &savedUkernels.find(key)->second; + } + } + + write_lock_guard_t w_g(g_brgemm_lock); + + if (savedUkernels.count(key) != 0) { + return &savedUkernels.find(key)->second; + } + + auto dnnl_dtypeA = static_cast(dtypeA); + auto dnnl_dtypeB = static_cast(dtypeB); + auto dnnl_dtypeC = static_cast(dtypeC); + + dnnl::ukernel::brgemm brg; + brg = dnnl::ukernel::brgemm(M, N, K_k, batch_size, lda, ldb, ldc, dnnl_dtypeA, + dnnl_dtypeB, dnnl_dtypeC); + // Instruct the kernel to append the result to C tensor. + brg.set_add_C(true); + // Finalize the initialization. + brg.finalize(); + + bool need_packing = brg.get_B_pack_type() == pack_type::pack32; + if (need_packing) { + brg = dnnl::ukernel::brgemm(M, N, K_k, batch_size, lda, N, ldc, dnnl_dtypeA, + dnnl_dtypeB, dnnl_dtypeC); + // Instruct the kernel to append the result to C tensor. + brg.set_add_C(true); + // Finalize the initialization. + brg.finalize(); + } + + // Generate the executable JIT code for the objects. + brg.generate(); + + dnnl::ukernel::transform tf; + if (need_packing) { + // Packing B tensor routine. The BRGeMM ukernel expects B passed in a + // special VNNI format for low precision data types, e.g., bfloat16_t. + // Note: the routine doesn't provide a `batch_size` argument in the + // constructor as it can be either incorporated into `K` dimension, or + // manually iterated over in a for-loop on the user side. + dnnl::ukernel::transform pack_B( + /* K = */ K_k, /* N = */ N, /* in_pack_type = */ pack_type::no_trans, + /* in_ld = */ ldb, /* out_ld = */ N, /* in_dt = */ dnnl_dtypeB, + /* out_dt = */ dnnl_dtypeB); + + pack_B.generate(); + tf = std::move(pack_B); + } + + auto it = savedUkernels.insert({key, {tf, brg}}); + return &it.first->second; +} + +EXPORT void brgemm_execute(const void *handle, void *A_ptr, + void *original_B_ptr, void *C_ptr, + int64_t A_step_in_bytes, int64_t B_step_in_bytes, + int64_t B_block_size_in_bytes, int64_t num_batches) { + + uint8_t *blocked_data = reinterpret_cast(original_B_ptr); + const uint8_t *B_ptr_calc = reinterpret_cast(original_B_ptr); + + const onednn_handle *kernel = reinterpret_cast(handle); + + const auto pack_B = kernel->transform; + const auto brg = kernel->brg; + + const bool need_packing = brg.get_B_pack_type() == pack_type::pack32; + if (need_packing) { + blocked_data = new uint8_t[B_block_size_in_bytes * num_batches]; + } + + brg.set_hw_context(); + + std::vector> A_B_offsets(num_batches); + for (memory::dim i = 0; i < num_batches; i++) { + const memory::dim A_offset_i = i * A_step_in_bytes; + + memory::dim B_offset_i; + if (need_packing) { + pack_B.execute(B_ptr_calc + i * B_step_in_bytes, + blocked_data + i * B_block_size_in_bytes); + B_offset_i = i * B_block_size_in_bytes; + } else { + B_offset_i = i * B_step_in_bytes; + } + A_B_offsets[i] = std::make_pair(A_offset_i, B_offset_i); + } + + size_t scratchpad_size = brg.get_scratchpad_size(); + std::vector scratchpad_sm(scratchpad_size); + // An execute call. `A_B` is a vector of pointers to A and packed B + // tensors. `acc_ptr` is a pointer to an accumulator buffer. 
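+  // In this wrapper the batch is described by A_B_offsets: per-batch byte
+  // offsets into A_ptr and into blocked_data (the packed copy of B when
+  // packing was needed, otherwise the original B buffer).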
+ brg.execute(A_ptr, blocked_data, A_B_offsets, C_ptr, scratchpad_sm.data()); + + dnnl::ukernel::brgemm::release_hw_context(); + + if (need_packing) { + delete[] blocked_data; + }; +} + +} // extern C diff --git a/third_party/cpu/triton_cpu.cc b/third_party/cpu/triton_cpu.cc index 50159cd94d33..a409099b15d1 100644 --- a/third_party/cpu/triton_cpu.cc +++ b/third_party/cpu/triton_cpu.cc @@ -8,11 +8,13 @@ #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/Passes.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" +#include "mlir/Transforms/Passes.h" #include "llvm/IR/Constants.h" #include "llvm/Support/TargetSelect.h" @@ -26,6 +28,17 @@ #include #include +#ifdef ONEDNN_AVAILABLE +#include "oneapi/dnnl/dnnl_config.h" +#endif +bool is_onednn_available() { +#ifdef DNNL_EXPERIMENTAL_UKERNEL + return true; +#else + return false; +#endif +} + namespace py = pybind11; void init_triton_cpu_passes_ttcpuir(py::module &&m) { @@ -35,6 +48,9 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { .value("libsleef", cpu::VecLib::Sleef) .value("libmvec", cpu::VecLib::Mvec); + py::enum_(m, "Ukernels") + .value("OneDNN", cpu::Ukernels::OneDNN); + m.def("add_scalarize", [](mlir::PassManager &pm, bool skip_gather_scatter) { pm.addPass( mlir::triton::cpu::createScalarizeUsingForOpPass(skip_gather_scatter)); @@ -86,6 +102,13 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { bool useHorizontalSum) { pm.addPass(mlir::triton::cpu::createConvertDotProduct(useHorizontalSum)); }); + m.def("add_loop_invariant_code_motion", [](mlir::PassManager &pm) { + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + }); + m.def("add_convert_dot_to_ukernels", [](mlir::PassManager &pm, + cpu::Ukernels ukernels) { + pm.addPass(mlir::triton::cpu::createConvertDotOpToUkernelOps(ukernels)); + }); m.def("add_convert_dot_to_amx", [](mlir::PassManager &pm, bool convertInt8, bool convertFp16, bool convertBf16) { pm.addPass(mlir::triton::cpu::createConvertDotToAMX( @@ -137,6 +160,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) { m.def("add_debug_ops_to_llvmir", [](mlir::PassManager &pm) { pm.addPass(mlir::triton::cpu::createDebugOpsToLLVMPass()); }); + m.def("add_ukernels_to_onednn_llvmir", [](mlir::PassManager &pm) { + pm.addPass(mlir::triton::cpu::createUkernelOpsToOneDNNLLVMPass()); + }); + m.def("add_expand_strided_metadata", [](mlir::PassManager &pm) { + pm.addPass(mlir::memref::createExpandStridedMetadataPass()); + }); m.def("add_vector_to_llvmir", [](mlir::PassManager &pm, bool reassoc_fp_reduction) { mlir::ConvertVectorToLLVMPassOptions opts; @@ -186,6 +215,8 @@ void init_triton_cpu(py::module &&m) { #endif // __linux__ && ARCH_REQ_XCOMP_PERM }); + m.def("onednn_available", is_onednn_available); + m.def("load_dialects", [](mlir::MLIRContext &context) { mlir::DialectRegistry registry; registry.insert Date: Wed, 19 Feb 2025 04:44:52 +0100 Subject: [PATCH 162/165] Add missing headers toruntime (#215) This commits adds missing headers to runtime files. 
Resolves: #180 Signed-off-by: Dmitrii Makarenko --- third_party/cpu/runtime/cpu_runtime.cpp | 1 + third_party/cpu/runtime/runtime_onednn.cpp | 2 ++ 2 files changed, 3 insertions(+) diff --git a/third_party/cpu/runtime/cpu_runtime.cpp b/third_party/cpu/runtime/cpu_runtime.cpp index 68b7efa78f01..6d13ae127920 100644 --- a/third_party/cpu/runtime/cpu_runtime.cpp +++ b/third_party/cpu/runtime/cpu_runtime.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #define __STDC_WANT_IEC_60559_TYPES_EXT__ diff --git a/third_party/cpu/runtime/runtime_onednn.cpp b/third_party/cpu/runtime/runtime_onednn.cpp index 1c3ce11ea939..5d2459e66d3d 100644 --- a/third_party/cpu/runtime/runtime_onednn.cpp +++ b/third_party/cpu/runtime/runtime_onednn.cpp @@ -7,12 +7,14 @@ #endif #endif +#include #include #include #include #include #include #include +#include #if defined(_MSC_VER) #define EXPORT __declspec(dllexport) From f46cf957358c2b19fe28faa6eea973232bf07088 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Thu, 20 Feb 2025 18:38:20 +0000 Subject: [PATCH 163/165] Allign with new LLVM version and remove deprecated calls. Signed-off-by: Dmitrii Makarenko --- .../Dialect/TritonCPU/IR/TritonCPUTypes.td | 2 +- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 26 ++++++++------ .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 3 +- .../TritonCPUToLLVM/LowerMultiReduction.cpp | 2 +- .../cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp | 5 +-- .../lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp | 11 +++--- .../UkernelOpsToOneDNNLLVM.cpp | 7 ++-- .../lib/TritonCPUTransforms/Canonicalize.cpp | 2 +- .../TritonCPUTransforms/ConvertDotProduct.cpp | 2 +- .../ConvertUnsupportedOps.cpp | 2 +- .../DecomposeFpConversions.cpp | 34 +++++++++---------- .../lib/TritonCPUTransforms/OptimizeMasks.cpp | 2 +- .../ScalarizeUsingForOps.cpp | 2 +- 13 files changed, 55 insertions(+), 45 deletions(-) diff --git a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td index d6ac013804c8..e35edaabe0d1 100644 --- a/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td +++ b/include/triton/Dialect/TritonCPU/IR/TritonCPUTypes.td @@ -24,7 +24,7 @@ def TTC_TokenType : TTC_TypeDef<"Token", "token"> { let skipDefaultBuilders = 1; } -def TTC_Vector : VectorOf<[TT_Float, TT_Int]>; +def TTC_Vector : VectorOfAnyRankOf<[TT_Float, TT_Int]>; def TTC_Type : AnyTypeOf<[TT_Float, TT_Int, TTC_Vector]>; diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index 33e1753e31b2..f5ff91a43327 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -73,16 +73,17 @@ Value printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto type = value.getType(); auto loc = UnknownLoc::get(context); + auto b = TritonLLVMOpBuilder(loc, rewriter); bool isUnsigned = type.isUnsignedInteger(); if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) { if (isUnsigned) { - return zext(ui32_ty, value); + return b.zext(ui32_ty, value); } else { - return sext(i32_ty, value); + return b.sext(i32_ty, value); } } else if (type.isBF16() || type.isF16() || type.isF32()) { - return fpext(f64_ty, value); + return b.fpext(f64_ty, value); } return value; @@ -161,6 +162,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, bool isSigned = false) { assert(!prefix.empty() && "printf with empty string not supported"); auto loc = 
UnknownLoc::get(rewriter.getContext()); + auto b = TritonLLVMOpBuilder(loc, rewriter); std::string formatStr; llvm::raw_string_ostream os(formatStr); @@ -180,7 +182,7 @@ void createRuntimePrintScalarCall(ConversionPatternRewriter &rewriter, allArgs.push_back(elem); if (arg.has_value()) allArgs.push_back(printfPromoteValue(rewriter, arg.value())); - call(getOrAddPrintFuncDecl(rewriter, true), allArgs); + b.call(getOrAddPrintFuncDecl(rewriter, true), allArgs); } void createRuntimePrintCall(ConversionPatternRewriter &rewriter, @@ -188,6 +190,7 @@ void createRuntimePrintCall(ConversionPatternRewriter &rewriter, Value ptr, Type dtype, bool isSigned, bool hex) { assert(!prefix.empty()); auto loc = UnknownLoc::get(rewriter.getContext()); + auto b = TritonLLVMOpBuilder(loc, rewriter); Value prefixValue = LLVM::addStringToModule( loc, rewriter, "vectorPrintPrefix_", makeNullTerminatedString(prefix)); @@ -198,12 +201,12 @@ void createRuntimePrintCall(ConversionPatternRewriter &rewriter, allArgs.push_back(prefixValue); allArgs.push_back(ptr); - allArgs.push_back(i32_val(dtype.getIntOrFloatBitWidth())); - allArgs.push_back(i32_val(dtype.isInteger())); - allArgs.push_back(i32_val(isSigned)); - allArgs.push_back(i32_val(hex)); + allArgs.push_back(b.i32_val(dtype.getIntOrFloatBitWidth())); + allArgs.push_back(b.i32_val(dtype.isInteger())); + allArgs.push_back(b.i32_val(isSigned)); + allArgs.push_back(b.i32_val(hex)); - call(getOrAddPrintMemrefFuncDecl(rewriter), allArgs); + b.call(getOrAddPrintMemrefFuncDecl(rewriter), allArgs); } bool usePrintf(triton::cpu::PrintOp op) { @@ -266,6 +269,7 @@ struct AssertOpConversion ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto ctx = rewriter.getContext(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto typeConverter = getTypeConverter(); Value message = LLVM::addStringToModule(loc, rewriter, "assertMessage_", @@ -292,8 +296,8 @@ struct AssertOpConversion makeNullTerminatedString(funcStr)); SmallVector args{getPid(op, 0), getPid(op, 1), getPid(op, 2), op.getCondition(), message, file, - i32_val(line), func}; - call(getAssertFuncDecl(rewriter), args); + b.i32_val(line), func}; + b.call(getAssertFuncDecl(rewriter), args); rewriter.eraseOp(op); return success(); } diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp index 99962da6546a..8ff02cf1fa70 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -155,9 +155,10 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern { Value packedResults = rewriter.create(op.getLoc(), packedResultsTy); auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); for (auto it : llvm::enumerate(adaptor.getOperands())) { packedResults = - insert_val(packedResultsTy, packedResults, it.value(), it.index()); + b.insert_val(packedResultsTy, packedResults, it.value(), it.index()); } newOp = rewriter.create(op.getLoc(), packedResults); } diff --git a/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp index 74f81cb0f9cc..2a2f1ad0d3b4 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/LowerMultiReduction.cpp @@ -41,7 +41,7 @@ struct LowerMultiReduction vector::populateVectorMultiReductionLoweringPatterns(loweringPatterns, options); - if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) + if 
(failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp index 2b1877c1c17b..aeaddef38eb8 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MathToVecLib.cpp @@ -157,6 +157,7 @@ struct PadSmallVecsForSleef : public OpRewritePattern { LogicalResult matchAndRewrite(ExternElementwiseOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); VectorType vecTy = dyn_cast(op.getType()); if (!vecTy) return failure(); @@ -171,7 +172,7 @@ struct PadSmallVecsForSleef : public OpRewritePattern { // Create a single-element vector for shuffle to use auto paddingVec = rewriter.create( - loc, undef(elemTy), VectorType::get({1}, elemTy)); + loc, b.undef(elemTy), VectorType::get({1}, elemTy)); // Assign indices such that shuffle will pad the original vector with // elements from the paddingVec SmallVector indices(4); @@ -397,7 +398,7 @@ struct MathToVecLibPass patterns.add(patterns.getContext()); patterns.add(patterns.getContext()); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); } diff --git a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp index a3fbf20a713e..3007dfc8e53e 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/MemoryOpToLLVM.cpp @@ -50,6 +50,7 @@ struct ExtractMemRefOpConversion : public OpConversionPattern { matchAndRewrite(ExtractMemRefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Value tensorPtrStruct = rewriter.getRemappedValue(op.getSrc()); auto memRefTy = cast(op.getType()); auto rank = memRefTy.getRank(); @@ -66,12 +67,12 @@ struct ExtractMemRefOpConversion : public OpConversionPattern { idxTo); }; - Value res = undef(memRefStructTy); + Value res = b.undef(memRefStructTy); // Copy base. res = copyValue(res, 0, 1); // Use 0 offset. res = rewriter.create(loc, memRefStructTy, res, - i64_val(0), 2); + b.i64_val(0), 2); // Copy shape. res = copyValue(res, 2, 3); // Copy strides. @@ -115,10 +116,11 @@ struct PtrToMemRefOpConversion : public OpConversionPattern { matchAndRewrite(PtrToMemRefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); Value ptr = rewriter.getRemappedValue(op.getSrc()); auto memRefStructTy = getTypeConverter()->convertType(op.getType()); - Value res = undef(memRefStructTy); + Value res = b.undef(memRefStructTy); res = rewriter.create(loc, memRefStructTy, res, ptr, 1); rewriter.replaceOp(op, res); @@ -134,6 +136,7 @@ struct MakeTensorPtrOpConversion : public OpConversionPattern { matchAndRewrite(MakeTensorPtrOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); auto structTy = getTypeConverter()->convertType(op.getType()); auto i64Ty = IntegerType::get(getContext(), 64); @@ -149,7 +152,7 @@ struct MakeTensorPtrOpConversion : public OpConversionPattern { return structVal; }; - Value res = undef(structTy); + Value res = b.undef(structTy); // 0 - base pointer. 
auto base = rewriter.getRemappedValue(op.getBase()); res = rewriter.create(loc, structTy, res, base, 0); diff --git a/third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp index dc56651e968b..7decd78ebf44 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/UkernelOpsToOneDNNLLVM.cpp @@ -99,13 +99,14 @@ struct BrgemmCreateConversion : public ConvertOpToLLVMPattern { matchAndRewrite(BrgemmCreate brgemmOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = brgemmOp.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); std::string dispatchName = "create_brgemm"; - auto lhsDnnType = i64_val(getDnnlDataTypeVal(adaptor.getDtypeA())); - auto rhsDnnType = i64_val(getDnnlDataTypeVal(adaptor.getDtypeB())); - auto accDnnType = i64_val(getDnnlDataTypeVal(adaptor.getDtypeC())); + auto lhsDnnType = b.i64_val(getDnnlDataTypeVal(adaptor.getDtypeA())); + auto rhsDnnType = b.i64_val(getDnnlDataTypeVal(adaptor.getDtypeB())); + auto accDnnType = b.i64_val(getDnnlDataTypeVal(adaptor.getDtypeC())); auto brgemmArgs = SmallVector{adaptor.getM(), adaptor.getN(), diff --git a/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp b/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp index 65fed92d2b50..8123766b7b42 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/Canonicalize.cpp @@ -90,7 +90,7 @@ struct Canonicalize : public triton::cpu::impl::CanonicalizeBase { RewritePatternSet patterns(context); patterns.add(context); - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + if (failed(mlir::applyPatternsGreedily(mod, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp index da96eea967cd..e4ac63076cea 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertDotProduct.cpp @@ -467,7 +467,7 @@ struct ConvertDotProduct patterns.add(context); } - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + if (failed(mlir::applyPatternsGreedily(mod, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp index 06ad1f1f6802..b232c1f73418 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp @@ -436,7 +436,7 @@ struct ConvertUnsupportedOps patterns.add>(context); } - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + if (failed(mlir::applyPatternsGreedily(mod, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp index 4a4c8bd8e448..d4e5a646bf39 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -222,13 +222,13 @@ Value convertToFp8(Location loc, Value src, Type dstFpTy, int dstExpBits, Value convertFp16ToFp8E4M3Rtz(Location loc, Value src, 
PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, false, + return convertToFp8(loc, src, rewriter.getType(), 4, 7, false, false, rewriter); } Value convertFp16ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, true, + return convertToFp8(loc, src, rewriter.getType(), 4, 7, true, false, rewriter); } @@ -271,7 +271,7 @@ Value convertFp16ToFp8E5M2B16Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, false, true, rewriter); } @@ -279,7 +279,7 @@ Value convertFp16ToFp8E5M2B16Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, true, true, rewriter); } @@ -287,7 +287,7 @@ Value convertBf16ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNType(), 4, 7, false, + return convertToFp8(loc, f32Src, rewriter.getType(), 4, 7, false, false, rewriter); } @@ -295,7 +295,7 @@ Value convertBf16ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E4M3FNType(), 4, 7, true, + return convertToFp8(loc, f32Src, rewriter.getType(), 4, 7, true, false, rewriter); } @@ -303,7 +303,7 @@ Value convertBf16ToFp8E5M2Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2Type(), 5, 15, false, + return convertToFp8(loc, f32Src, rewriter.getType(), 5, 15, false, false, rewriter); } @@ -311,7 +311,7 @@ Value convertBf16ToFp8E5M2Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2Type(), 5, 15, true, + return convertToFp8(loc, f32Src, rewriter.getType(), 5, 15, true, false, rewriter); } @@ -319,7 +319,7 @@ Value convertBf16ToFp8E5M2B16Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, false, true, rewriter); } @@ -327,43 +327,43 @@ Value convertBf16ToFp8E5M2B16Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getFloat8E5M2FNUZType(), 5, 16, + return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, true, true, rewriter); } Value convertFp32ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, false, + return convertToFp8(loc, src, rewriter.getType(), 4, 7, false, false, rewriter); } Value convertFp32ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E4M3FNType(), 4, 7, true, + return 
convertToFp8(loc, src, rewriter.getType(), 4, 7, true, false, rewriter); } Value convertFp32ToFp8E5M2Rtz(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E5M2Type(), 5, 15, false, + return convertToFp8(loc, src, rewriter.getType(), 5, 15, false, false, rewriter); } Value convertFp32ToFp8E5M2Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E5M2Type(), 5, 15, true, + return convertToFp8(loc, src, rewriter.getType(), 5, 15, true, false, rewriter); } Value convertFp32ToFp8E5M2B16Rtz(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E5M2FNUZType(), 5, 16, false, + return convertToFp8(loc, src, rewriter.getType(), 5, 16, false, true, rewriter); } Value convertFp32ToFp8E5M2B16Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getFloat8E5M2FNUZType(), 5, 16, true, + return convertToFp8(loc, src, rewriter.getType(), 5, 16, true, true, rewriter); } @@ -520,7 +520,7 @@ struct DecomposeFpConversions patterns.add(context); } - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + if (failed(mlir::applyPatternsGreedily(mod, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp index e747ef16c957..dccd6a2321d5 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/OptimizeMasks.cpp @@ -351,7 +351,7 @@ struct OptimizeMasks patterns.add(context); patterns.add(context); patterns.add(context); - if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)))) + if (failed(mlir::applyPatternsGreedily(mod, std::move(patterns)))) return signalPassFailure(); // TODO: if masks removal failed for loads/stores in a for-loop, we might diff --git a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp index 0e8102831e1e..ff389c174aed 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ScalarizeUsingForOps.cpp @@ -381,7 +381,7 @@ struct ScalarizeUsingForOpPass ScalarizeOpConversion>( axisInfoAnalysis, context, skipGatherScatter); - if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) { + if (applyPatternsGreedily(mod, std::move(patterns)).failed()) { return signalPassFailure(); } } From 5f23d7b50067bc2daf914816166addc613dbaf36 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Fri, 21 Feb 2025 12:20:22 +0000 Subject: [PATCH 164/165] rebase fixes --- third_party/cpu/backend/compiler.py | 2 +- third_party/cpu/backend/driver.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/third_party/cpu/backend/compiler.py b/third_party/cpu/backend/compiler.py index 6736976cc168..cff0f55722dc 100644 --- a/third_party/cpu/backend/compiler.py +++ b/third_party/cpu/backend/compiler.py @@ -127,7 +127,7 @@ def parse_options(self, opts) -> Any: def pack_metadata(self, metadata): return metadata - def get_codegen_implementation(self): + def get_codegen_implementation(self, options): codegen_fns = {"min_dot_size": min_dot_size(self.target)} return codegen_fns diff --git a/third_party/cpu/backend/driver.py b/third_party/cpu/backend/driver.py index 3308fd23c680..8ef6ac35cb6f 100644 --- a/third_party/cpu/backend/driver.py +++ 
b/third_party/cpu/backend/driver.py @@ -435,6 +435,10 @@ def __init__(self): def get_current_device(self): return 0 + def get_active_torch_device(self): + import torch + return torch.device("cpu", self.get_current_device()) + def get_current_stream(self, device): return 0 From 95e6cbed92b290014e00c4c3ae0678eab9411796 Mon Sep 17 00:00:00 2001 From: Dmitrii Makarenko Date: Fri, 21 Feb 2025 13:33:27 +0000 Subject: [PATCH 165/165] rebase issues --- python/test/unit/language/print_helper.py | 2 +- python/test/unit/language/test_core.py | 1 + python/test/unit/language/test_subprocess.py | 2 - python/triton/runtime/build.py | 28 -------- python/triton/runtime/jit.py | 6 +- python/triton/testing.py | 2 +- python/tutorials/01-vector-add.py | 3 +- .../lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp | 2 +- .../cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp | 4 +- .../DecomposeFpConversions.cpp | 64 +++++++++---------- 10 files changed, 44 insertions(+), 70 deletions(-) diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index 07cc1cc7223c..15d71e5e6f58 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -101,7 +101,7 @@ def kernel_print_2d_tensor(X, Y, BLOCK_SIZE_X: tl.constexpr, BLOCK_SIZE_Y: tl.co def test_print(func: str, data_type: str, device: str): - N = 128 # This value should match with test_print in test_subprocess.py. + N = 128 # This value should match with test_print in test_subprocess.py. SCALAR = 42 # TODO(antiagainst): Currently the warp count is chosen to make sure we don't have multiple diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index afce8d7f65a3..45c005e0ca2d 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -331,6 +331,7 @@ def kernel(x): kernel[(1, )](2) + def test_scalar_overflow(device): @triton.jit diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index bd0c80fca80e..c50d92788059 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -8,8 +8,6 @@ import pytest -import triton - dir_path = os.path.dirname(os.path.realpath(__file__)) print_path = os.path.join(dir_path, "print_helper.py") torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"] diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py index 58a52fb8f82f..ab132ce7cd8d 100644 --- a/python/triton/runtime/build.py +++ b/python/triton/runtime/build.py @@ -6,7 +6,6 @@ import os import shutil import subprocess -import setuptools @contextlib.contextmanager @@ -96,31 +95,4 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries): ret = subprocess.check_call(cc_cmd) if ret == 0: return so - # fallback on setuptools - extra_compile_args = [] - # extra arguments - extra_link_args = [] - # create extension module - ext = setuptools.Extension( - name=name, - language='c', - sources=[src], - include_dirs=include_dirs, - extra_compile_args=extra_compile_args + ['-O3'], - extra_link_args=extra_link_args, - library_dirs=library_dirs, - libraries=libraries, - ) - # build extension module - args = ['build_ext'] - args.append('--build-temp=' + srcdir) - args.append('--build-lib=' + srcdir) - args.append('-q') - args = dict( - name=name, - ext_modules=[ext], - script_args=args, - ) - with quiet(): - setuptools.setup(**args) return so diff --git a/python/triton/runtime/jit.py 
b/python/triton/runtime/jit.py index f0454993c129..2ae5345e95cd 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -526,6 +526,7 @@ def run(self, *args, grid, warmup, **kwargs): kwargs["debug"] = kwargs.get("debug", self.debug) or os.environ.get("TRITON_DEBUG", "0") == "1" # parse options + device_key = get_device_key() device = driver.active.get_current_device() stream = driver.active.get_current_stream(device) @@ -533,7 +534,7 @@ def run(self, *args, grid, warmup, **kwargs): for hook in self.pre_run_hooks: hook(*args, **kwargs) - kernel_cache, target, backend, binder = self.device_caches[device] + kernel_cache, target, backend, binder = self.device_caches[device_key] bound_args, specialization, options = binder(*args, **kwargs) # compute cache key @@ -675,6 +676,7 @@ def preload(self, specialization_data): import json import triton.language as tl device_key = get_device_key() + # device = driver.active.get_current_device() deserialized_obj = json.loads(specialization_data) if deserialized_obj['name'] != self.fn.__name__: raise RuntimeError( @@ -696,7 +698,7 @@ def preload(self, specialization_data): } key = deserialized_obj['key'] kernel = compile(src, None, options) - self.device_caches[device][0][key] = kernel + self.device_caches[device_key][0][key] = kernel return kernel # we do not parse `src` in the constructor because diff --git a/python/triton/testing.py b/python/triton/testing.py index 54d59de6066c..6d8443fce7d7 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -188,7 +188,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_m # Record clocks di.synchronize() - times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float) + times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)] return _summarize_statistics(times, quantiles, return_mode) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index f6eb176af82a..22d3d8ae8d0b 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -30,6 +30,7 @@ CPU_ST_THRESHOLD = 65536 USE_GPU = False + @triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. 
@@ -203,7 +204,7 @@ def add_tiled_autotuned(x: torch.Tensor, y: torch.Tensor, output): print(output_torch_gpu) print(output_triton_gpu) print(f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch_gpu - output_triton_gpu))}') + f'{torch.max(torch.abs(output_torch_gpu - output_triton_gpu))}') LINE_VALS += ['triton-gpu', 'torch-gpu'] LINE_NAMES += ['TritonGPU', 'TorchGPU'] diff --git a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp index f5ff91a43327..f2805935fa34 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/DebugOpsToLLVM.cpp @@ -296,7 +296,7 @@ struct AssertOpConversion makeNullTerminatedString(funcStr)); SmallVector args{getPid(op, 0), getPid(op, 1), getPid(op, 2), op.getCondition(), message, file, - b.i32_val(line), func}; + b.i32_val(line), func}; b.call(getAssertFuncDecl(rewriter), args); rewriter.eraseOp(op); return success(); diff --git a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp index 8ff02cf1fa70..ba40ae5fcf4d 100644 --- a/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp +++ b/third_party/cpu/lib/TritonCPUToLLVM/FuncOpToLLVM.cpp @@ -157,8 +157,8 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern { auto loc = op.getLoc(); auto b = TritonLLVMOpBuilder(loc, rewriter); for (auto it : llvm::enumerate(adaptor.getOperands())) { - packedResults = - b.insert_val(packedResultsTy, packedResults, it.value(), it.index()); + packedResults = b.insert_val(packedResultsTy, packedResults, it.value(), + it.index()); } newOp = rewriter.create(op.getLoc(), packedResults); } diff --git a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp index d4e5a646bf39..0347e0fb5628 100644 --- a/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp +++ b/third_party/cpu/lib/TritonCPUTransforms/DecomposeFpConversions.cpp @@ -222,14 +222,14 @@ Value convertToFp8(Location loc, Value src, Type dstFpTy, int dstExpBits, Value convertFp16ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 4, 7, false, - false, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 4, 7, + false, false, rewriter); } Value convertFp16ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 4, 7, true, - false, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 4, 7, + true, false, rewriter); } Value convertFp16ToFp8E5M2Rtz(Location loc, Value src, @@ -271,100 +271,100 @@ Value convertFp16ToFp8E5M2B16Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, - false, true, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), + 5, 16, false, true, rewriter); } Value convertFp16ToFp8E5M2B16Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, - true, true, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), + 5, 16, true, true, rewriter); } Value convertBf16ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, 
toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 4, 7, false, - false, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), 4, + 7, false, false, rewriter); } Value convertBf16ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 4, 7, true, - false, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), 4, + 7, true, false, rewriter); } Value convertBf16ToFp8E5M2Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 5, 15, false, - false, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), 5, + 15, false, false, rewriter); } Value convertBf16ToFp8E5M2Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 5, 15, true, - false, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), 5, + 15, true, false, rewriter); } Value convertBf16ToFp8E5M2B16Rtz(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, - false, true, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), + 5, 16, false, true, rewriter); } Value convertBf16ToFp8E5M2B16Rtne(Location loc, Value src, PatternRewriter &rewriter) { Value f32Src = rewriter.create(loc, toFp32(src.getType()), src); - return convertToFp8(loc, f32Src, rewriter.getType(), 5, 16, - true, true, rewriter); + return convertToFp8(loc, f32Src, rewriter.getType(), + 5, 16, true, true, rewriter); } Value convertFp32ToFp8E4M3Rtz(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 4, 7, false, - false, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 4, 7, + false, false, rewriter); } Value convertFp32ToFp8E4M3Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 4, 7, true, - false, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 4, 7, + true, false, rewriter); } Value convertFp32ToFp8E5M2Rtz(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 5, 15, false, - false, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 5, 15, + false, false, rewriter); } Value convertFp32ToFp8E5M2Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 5, 15, true, - false, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 5, 15, + true, false, rewriter); } Value convertFp32ToFp8E5M2B16Rtz(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 5, 16, false, - true, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 5, + 16, false, true, rewriter); } Value convertFp32ToFp8E5M2B16Rtne(Location loc, Value src, PatternRewriter &rewriter) { - return convertToFp8(loc, src, rewriter.getType(), 5, 16, true, - true, rewriter); + return convertToFp8(loc, src, rewriter.getType(), 5, + 16, true, true, rewriter); } FpToFpConvFn