Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix barrier insertion after assert op #5114

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
// know about the op to split the block.
void llAssert(Operation *op, Value condition, StringRef message,
ConversionPatternRewriter &rewriter) const {
ConversionPatternRewriter::InsertionGuard guard(rewriter);

auto ctx = rewriter.getContext();
auto loc = op->getLoc();
Expand Down Expand Up @@ -87,6 +86,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
rewriter.create<cf::BranchOp>(loc, thenBlock);
rewriter.setInsertionPointToEnd(prevBlock);
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
rewriter.setInsertionPointToStart(thenBlock);
}

protected:
Expand Down
19 changes: 16 additions & 3 deletions python/test/unit/test_debug.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import pytest
import torch
import triton.language as tl
Expand All @@ -10,8 +9,8 @@
@pytest.mark.parametrize('env_var', [True, False])
@pytest.mark.parametrize('jit_flag', [True, False])
@pytest.mark.forked
def test_device_assert(cond, opt_flag, env_var, jit_flag, device):
os.environ['TRITON_DEBUG'] = str(int(env_var))
def test_device_assert(monkeypatch, cond, opt_flag, env_var, jit_flag, device):
monkeypatch.setenv("TRITON_DEBUG", str(int(env_var)))
torch.zeros([1], dtype=torch.int32, device=device)

@triton.jit(debug=jit_flag)
Expand All @@ -34,6 +33,20 @@ def _kernel(COND: tl.constexpr):
getattr(torch, device).synchronize()


def test_device_assert_barrier(monkeypatch, device):
monkeypatch.setenv("TRITON_DEBUG", "1")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general such bugs transformations it is better to do lit tests to test specifically the case that cause the bug

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ThomasRaoux Thanks for the advice!

Should I try writing a lit test now?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be nice to replace this with a lit if you have a chance

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that this issue is somehow related to multithreading. I have a lit test (unfortunately only for the XPU backend) and if I also specify --mlir-disable-threading, the issue goes away. Also, it seems that this is still a problem of LLVM, could you give some advice on how to rewrite the reproducer in order to fill out an issue for them? Or maybe there are some ideas on what can be done about this?

Lit test
// RUN: triton-opt %s --convert-triton-intel-gpu-to-llvm

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// CHECK: llvm.call @__assertfail
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @_kernel_qwerty(%arg0: !tt.ptr<i32>) {
    %cst = arith.constant dense<1> : tensor<8xi32, #blocked>
    %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked>
    %1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<8x!tt.ptr<i32>, #blocked>
    %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr<i32>, #blocked>, tensor<8xi32, #blocked>
    %3 = tt.load %2 : tensor<8x!tt.ptr<i32>, #blocked>
    %4 = arith.cmpi slt, %3, %cst : tensor<8xi32, #blocked>
    tt.assert %4, "" : tensor<8xi1, #blocked>
    tt.return
  }
}

Stack trace
RUN: at line 1: .../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt .../intel-xpu-backend-for-triton/test/Conversion/intel/tritonintelgpu_to_llvm.mlir --convert-triton-intel-gpu-to-llvm
+ .../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt .../intel-xpu-backend-for-triton/test/Conversion/intel/tritonintelgpu_to_llvm.mlir --convert-triton-intel-gpu-to-llvm
triton-opt: /home/runner/work/triton/triton/llvm-project/llvm/include/llvm/ADT/ilist_iterator.h:168: llvm::ilist_iterator::reference llvm::ilist_iterator<llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, false, false>::operator*() const [OptionsT = llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, IsReverse = false, IsConst = false]: Assertion `!NodePtr->isKnownSentinel()' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
 #0 0x000055d5cf9bf447 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x312f447)
 #1 0x000055d5cf9bcf6e llvm::sys::RunSignalHandlers() (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x312cf6e)
 #2 0x000055d5cf9bfaff SignalHandler(int) Signals.cpp:0:0
 #3 0x00007f25ba794520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f25ba7e89fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x00007f25ba7e89fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
 #6 0x00007f25ba7e89fc pthread_kill ./nptl/pthread_kill.c:89:10
 #7 0x00007f25ba794476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007f25ba77a7f3 abort ./stdlib/abort.c:81:7
 #9 0x00007f25ba77a71b _nl_load_domain ./intl/loadmsgcat.c:1177:9
#10 0x00007f25ba78be96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x000055d5cf90c137 mlir::Operation::updateOrderIfNecessary() (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x307c137)
#12 0x000055d5cf90bf7f mlir::Operation::isBeforeInBlock(mlir::Operation*) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x307bf7f)
#13 0x000055d5cf8e4222 mlir::DominanceInfo::properlyDominates(mlir::Value, mlir::Operation*) const (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x3054222)
#14 0x000055d5cf92b224 (anonymous namespace)::OperationVerifier::verifyOpAndDominance(mlir::Operation&) Verifier.cpp:0:0
#15 0x000055d5cf92c9be std::_Function_handler<void (), llvm::LogicalResult mlir::failableParallelForEach<mlir::Operation**, (anonymous namespace)::OperationVerifier::verifyOnExit(mlir::Operation&)::$_3>(mlir::MLIRContext*, mlir::Operation**, mlir::Operation**, (anonymous namespace)::OperationVerifier::verifyOnExit(mlir::Operation&)::$_3&&)::'lambda'()>::_M_invoke(std::_Any_data const&) Verifier.cpp:0:0
#16 0x000055d5cea285e8 std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<std::function<void ()>>>, void>>::_M_invoke(std::_Any_data const&) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x21985e8)
#17 0x000055d5cea28547 std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x2198547)
#18 0x00007f25ba7ebee8 __pthread_once_slow ./nptl/pthread_once.c:118:7
#19 0x000055d5cea288fb std::__future_base::_Deferred_state<std::thread::_Invoker<std::tuple<std::function<void ()>>>, void>::_M_complete_async() (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x21988fb)
#20 0x000055d5cea289a9 void std::__invoke_impl<void, std::shared_future<void> llvm::ThreadPoolInterface::asyncImpl<void>(std::function<void ()>, llvm::ThreadPoolTaskGroup*)::'lambda'()&>(std::__invoke_other, std::shared_future<void> llvm::ThreadPoolInterface::asyncImpl<void>(std::function<void ()>, llvm::ThreadPoolTaskGroup*)::'lambda'()&) .../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x21989a9)
#21 0x000055d5cf99c1db llvm::StdThreadPool::processTasks(llvm::ThreadPoolTaskGroup*) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x310c1db)
#22 0x000055d5cf99d4d7 void* llvm::thread::ThreadProxy<std::tuple<llvm::StdThreadPool::grow(int)::$_0>>(void*) ThreadPool.cpp:0:0
#23 0x00007f25ba7e6ac3 start_thread ./nptl/pthread_create.c:442:8
#24 0x00007f25ba878850 ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:83:0

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing the fact that it goes away without multi-threading is just that there is some memory corruption and it passes by luck otherwise. Is the barrier inserted at the right spot in this case? Checking that should be enough

If you think it is an issue in MLIR core you can file an issue in llvm github: https://github.com/llvm/llvm-project

tensor = torch.zeros([16], dtype=torch.int32, device=device)

@triton.jit
def _kernel(in_ptr0):
xindex = tl.arange(0, 8)
tmp0 = tl.load(in_ptr0 + xindex)
tl.device_assert(tmp0 < 1)

_kernel[(1, )](tensor)
getattr(torch, device).synchronize()


@pytest.mark.parametrize("cond", [False, True])
def test_static_assert(cond):

Expand Down
Loading