
error: Incompatible parent encoding due to use of mask in load #1528

Closed

3outeille opened this issue Apr 14, 2023 · 1 comment
Given the following matrix-multiplication kernel (which I ran on a V100): when IS_MASKED=True, the error Incompatible parent encoding appears, which is not the case when IS_MASKED=False. Any idea what causes this problem?

import torch

import triton
import triton.language as tl


@triton.jit
def _kernel(
    debug_qk_ptr, q_ptr, k_ptr,
    q_batch_stride, q_head_stride, q_m_stride, q_k_stride,
    k_batch_stride, k_head_stride, k_n_stride, k_k_stride,
    head_size, seq_len,
    IS_MASKED: tl.constexpr,
    BLOCK: tl.constexpr,
    BLOCK_DHEAD_SIZE: tl.constexpr,
):
    
    block_m_idx = tl.program_id(0)
    head_idx = tl.program_id(1)
    current_batch_idx = head_idx // head_size
    current_head_idx = head_idx % head_size

    m_range_offs = tl.arange(0, BLOCK)
    n_range_offs = tl.arange(0, BLOCK)
    dhead_range_offs = tl.arange(0, BLOCK_DHEAD_SIZE)

    m_offs = block_m_idx * BLOCK + m_range_offs

    q_offs = (
        current_batch_idx * q_batch_stride + 
        current_head_idx * q_head_stride + 
        (m_offs[:, None] * q_m_stride + dhead_range_offs[None, :] * q_k_stride)
    )

    k_offs = (
        current_batch_idx * k_batch_stride
        + current_head_idx * k_head_stride
        + (n_range_offs[None, :] * k_n_stride + dhead_range_offs[:, None] * k_k_stride)
    )
    
    q_ptrs = q_ptr + q_offs
    k_ptrs = k_ptr + k_offs

    if IS_MASKED:
        q = tl.load(q_ptrs, mask=m_offs[:, None] < seq_len, other=0.0)
    else:
        q = tl.load(q_ptrs)

    for block_n_start_idx in range(0, seq_len, BLOCK):
        block_n_offs = block_n_start_idx + n_range_offs
        
        if IS_MASKED:
            k_ptr_mask = block_n_offs[None, :] < seq_len
            k = tl.load(k_ptrs + block_n_start_idx * k_n_stride, mask=k_ptr_mask, other=0.0)
        else:
            k = tl.load(k_ptrs + block_n_start_idx * k_n_stride)

        qk = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)

        # NOTE: this pre-fills the out-of-range columns with -inf *before*
        # the dot product below; the follow-up comment identifies this line
        # as the culprit.
        if IS_MASKED:
            qk = tl.where(n_range_offs[None, :] < seq_len, qk, float("-inf"))
        
        qk += tl.dot(q, k)

        debug_qk_ptrs = debug_qk_ptr + tl.arange(0, BLOCK)[None, :] + tl.arange(0, BLOCK)[:, None] * BLOCK
        tl.store(debug_qk_ptrs, qk)


BLOCK_SIZE = 64
Z, H, N_CTX, D_HEAD = 1, 1, 32, 32
q = torch.arange(Z * H * N_CTX * D_HEAD, dtype=torch.float16, device="cuda").reshape(Z, H, N_CTX, D_HEAD) / 100
k = torch.arange(Z * H * N_CTX * D_HEAD, dtype=torch.float16, device="cuda").reshape(Z, H, N_CTX, D_HEAD) / 100
debug_qk = torch.zeros((BLOCK_SIZE, BLOCK_SIZE), dtype=torch.float32, device=q.device)

grid = (1, 1)

_kernel[grid](
    debug_qk, q, k,
    *q.stride(), *k.stride(),
    H, N_CTX,
    IS_MASKED=True,
    BLOCK=BLOCK_SIZE,
    BLOCK_DHEAD_SIZE=D_HEAD,
    num_warps=4 if k.size(3) <= 64 else 8,
    num_stages=2,
)

ref_out = torch.matmul(q.to(torch.float32), k.transpose(3, 2).to(torch.float32))
assert torch.allclose(ref_out[0, 0, ...], debug_qk[:32, :32])
3outeille (Author) commented Apr 14, 2023

Seems like commenting out this line, if IS_MASKED: qk = tl.where(n_range_offs[None, :] < seq_len, qk, float("-inf")), fixed the issue (I was accumulating the dot product into entries already set to -inf, which is why it doesn't work).
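
For anyone hitting the same thing, a minimal sketch of the usual workaround, assuming the same kernel and variable names as the repro above (untested, illustrative only): keep the accumulator at zero, run the dot product first, and apply the -inf masking to the scores afterwards, the way attention kernels mask before a softmax. It also bounds the columns with block_n_offs (the absolute column index inside the loop) rather than n_range_offs, which is only correct on the first block. The snippet is a drop-in replacement for the loop body:

    for block_n_start_idx in range(0, seq_len, BLOCK):
        block_n_offs = block_n_start_idx + n_range_offs

        if IS_MASKED:
            k_ptr_mask = block_n_offs[None, :] < seq_len
            k = tl.load(k_ptrs + block_n_start_idx * k_n_stride, mask=k_ptr_mask, other=0.0)
        else:
            k = tl.load(k_ptrs + block_n_start_idx * k_n_stride)

        # Accumulate first: -inf + x stays -inf, so seeding the accumulator
        # with -inf before the dot product wipes out the masked columns.
        qk = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
        qk += tl.dot(q, k)

        # Mask out-of-range columns only once the scores exist, using the
        # absolute column index for this loop iteration.
        if IS_MASKED:
            qk = tl.where(block_n_offs[None, :] < seq_len, qk, float("-inf"))

        debug_qk_ptrs = debug_qk_ptr + tl.arange(0, BLOCK)[None, :] + tl.arange(0, BLOCK)[:, None] * BLOCK
        tl.store(debug_qk_ptrs, qk)

With seq_len=32 and BLOCK=64 the loop runs once, and the [:32, :32] slice compared against the reference is unaffected by the -inf entries in the masked columns.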

ZzEeKkAa pushed a commit to ZzEeKkAa/triton that referenced this issue Aug 5, 2024
These fixes allow the Triton project to build under gcc-9.

cc triton-lang#1505