Skip to content

Commit

Permalink
Small fixes for clang + macosx (triton-lang#173)
Browse files Browse the repository at this point in the history
  • Loading branch information
digantdesai authored and Devjiu committed Feb 20, 2025
1 parent 3d528f7 commit 4a778e6
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 8 deletions.
29 changes: 24 additions & 5 deletions python/src/llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,28 @@ struct BreakStructPhiNodesPass : PassInfoMixin<BreakStructPhiNodesPass> {

using namespace llvm;

std::string getDefaultTargerOrProcessTriple() {
// Return process triple iff the default target triple is empty.
std::string triple = llvm::sys::getDefaultTargetTriple();
if (triple.empty()) {
// host
triple = llvm::sys::getProcessTriple();
}
return triple;
}

std::unique_ptr<TargetMachine>
createTargetMachine(llvm::Module *module, std::string proc,
bool enable_fp_fusion, const std::string &features,
bool enable_fast_math = false) {
auto triple = getDefaultTargerOrProcessTriple();
module->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module->getTargetTriple(), error);
if (!target) {
throw std::runtime_error("target lookup error: " + error);
}
llvm::TargetOptions opt;
bool disableLLVMOpt = mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT");
if (enable_fp_fusion)
Expand Down Expand Up @@ -418,10 +433,14 @@ void init_triton_llvm(py::module &&m) {
py::arg("enable_fp_fusion") = false);

m.def("set_host_target", [](llvm::Module *mod) {
mod->setTargetTriple(llvm::sys::getDefaultTargetTriple());
auto triple = getDefaultTargerOrProcessTriple();
mod->setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(mod->getTargetTriple(), error);
if (!target) {
throw std::runtime_error("target lookup error: " + error);
}
std::unique_ptr<llvm::TargetMachine> machine{target->createTargetMachine(
mod->getTargetTriple(), llvm::sys::getHostCPUName(), "", {},
llvm::Reloc::PIC_)};
Expand All @@ -448,10 +467,10 @@ void init_triton_llvm(py::module &&m) {
"failed to parse IR: " + error.getMessage() +
"lineno: " + std::to_string(error.getLineNo()));
}
res =
translateLLVMIRToASM(*module, llvm::sys::getDefaultTargetTriple(),
llvm::sys::getHostCPUName().str(), "", {},
enable_fp_fusion, false, enable_fast_math);
auto triple = getDefaultTargerOrProcessTriple();
res = translateLLVMIRToASM(*module, triple,
llvm::sys::getHostCPUName().str(), "", {},
enable_fp_fusion, false, enable_fast_math);
}
return py::str(res);
},
Expand Down
9 changes: 8 additions & 1 deletion python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,24 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
# for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-Wno-psabi", "-o", so]

libraries += ["gcc"]
# Use dynamic lookup to load Python library on Mac
if system == "Darwin":
cc_cmd += ["-undefined", "dynamic_lookup"]
# Don't use libgcc on clang + macos
if "clang" in cc:
libraries.remove("gcc")
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
for dir in library_dirs:
cc_cmd.extend(["-Wl,-rpath", dir])
# CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag.
if src.endswith(".cpp") or src.endswith(".cc"):
cc_cmd += ["-std=c++17", "-fopenmp"]
cc_cmd += ["-std=c++17"]
if not os.environ.get("TRITON_DISABLE_OPENMP", None):
cc_cmd += ["-fopenmp"]
if src.endswith(".s"):
# This is required to properly parse .file directives
cc_cmd += ["-g"]
Expand Down
2 changes: 1 addition & 1 deletion third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def make_so(src, metadata, options):
asm_path = os.path.join(tmpdir, "kernel.s")
Path(asm_path).write_text(src)
lib_dirs = cpu_driver.library_dirs
libs = ["gcc", "m", "TritonCPURuntime", "sleef"]
libs = ["m", "TritonCPURuntime", "sleef"]
so = _build("kernel", asm_path, tmpdir, lib_dirs, cpu_driver.include_dirs, libs)
with open(so, "rb") as f:
return f.read()
Expand Down
10 changes: 9 additions & 1 deletion third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ def format_of(ty):
#include <cstdlib>
#include <iomanip>
#include <iostream>
#ifdef _OPENMP
#include <omp.h>
#endif // _OPENMP
#include <optional>
#include <stdio.h>
#include <string>
Expand Down Expand Up @@ -238,7 +240,11 @@ def format_of(ty):
}}
auto all_grids = get_all_grids(gridX, gridY, gridZ);
int max_threads = (num_threads > 0) ? num_threads : omp_get_max_threads();
int omp_max_threads = 1;
#ifdef _OPENMP
omp_max_threads = omp_get_max_threads();
#endif // _OPENMP
int max_threads = (num_threads > 0) ? num_threads : omp_max_threads;
// Don't pay OMP overhead price when a single thread is used.
if (max_threads == 1) {{
Expand All @@ -250,7 +256,9 @@ def format_of(ty):
}}
// For now, use the default chunk size, total iterations / max_threads.
#ifdef _OPENMP
#pragma omp parallel for schedule(static) num_threads(max_threads)
#endif // _OPENMP
for (size_t i = 0; i < N; ++i) {{
const auto [x, y, z] = all_grids[i];
(*kernel_ptr)({kernel_fn_args_list + ', ' if len(kernel_fn_args) > 0 else ''} x, y, z, gridX, gridY, gridZ);
Expand Down

0 comments on commit 4a778e6

Please sign in to comment.