From b316ab22fef668f97765d0126bd64e4985b683e5 Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Mon, 27 Jan 2025 17:31:16 -0800 Subject: [PATCH] [Neuron][Kernel] NKI-based flash-attention kernel with paged KV cache (#11277) Signed-off-by: Liangfu Chen Co-authored-by: Jiangfei Duan --- .buildkite/run-neuron-test.sh | 2 +- tests/neuron/test_prefix_prefill.py | 456 ++++++++++++++++++ vllm/attention/ops/nki_flash_attn.py | 669 +++++++++++++++++++++++++++ 3 files changed, 1126 insertions(+), 1 deletion(-) create mode 100644 tests/neuron/test_prefix_prefill.py create mode 100644 vllm/attention/ops/nki_flash_attn.py diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh index 0590dad4f311f..1ad77cf50f612 100644 --- a/.buildkite/run-neuron-test.sh +++ b/.buildkite/run-neuron-test.sh @@ -54,4 +54,4 @@ docker run --rm -it --device=/dev/neuron0 --device=/dev/neuron1 --network host \ -e "NEURON_COMPILE_CACHE_URL=${NEURON_COMPILE_CACHE_MOUNT}" \ --name "${container_name}" \ ${image_name} \ - /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py" + /bin/bash -c "python3 /workspace/vllm/examples/offline_inference/neuron.py && python3 -m pytest /workspace/vllm/tests/neuron/ -v --capture=tee-sys" diff --git a/tests/neuron/test_prefix_prefill.py b/tests/neuron/test_prefix_prefill.py new file mode 100644 index 0000000000000..77b707a737118 --- /dev/null +++ b/tests/neuron/test_prefix_prefill.py @@ -0,0 +1,456 @@ +import random +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F + + +class BlockDiagonalCausalFromBottomRightMask: + + @staticmethod + def _from_seqlens(query_lens, seq_lens, block_size=None): + from torch import logical_and, logical_or + + contexted = block_size is None + context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) + n_queries = sum(query_lens) + num_seqs = len(query_lens) + if contexted: + key_lens_blockaligned = seq_lens + else: + n_blocks_per_seq = (context_lens + block_size - 1) // block_size + offset_per_seq = n_blocks_per_seq * block_size + key_lens_blockaligned = offset_per_seq[:num_seqs].tolist() + n_keys = sum(key_lens_blockaligned) + + a = (torch.arange(n_queries).reshape(n_queries, + 1).expand(n_queries, n_keys)) + b = torch.arange(n_keys).reshape(1, n_keys).expand(n_queries, n_keys) + q_cumsum = torch.tensor([0] + query_lens).cumsum(dim=0) + k_cumsum = torch.tensor([0] + key_lens_blockaligned).cumsum(dim=0) + + prior_mask = torch.zeros(n_queries, n_keys) + new_masks: list[torch.Tensor] = [] + for seq_id in range(num_seqs): + ri = q_cumsum[seq_id] + ci = k_cumsum[seq_id] + nr = query_lens[seq_id] + + if contexted: + nc = seq_lens[seq_id] + a_offset = ci + nc - ri - nr + new_mask = (a + a_offset) >= b + else: + nc = context_lens[seq_id] + a_offset = ci + nc - 1 + new_mask = a_offset >= b + + left_mask = b >= ci + top_mask = a >= ri + bottom_mask = a < (ri + nr) + + new_mask = logical_and( + logical_and(logical_and(new_mask, left_mask), top_mask), + bottom_mask, + ) + prior_mask = logical_or(prior_mask, new_mask) + new_masks = new_masks + [new_mask] + return prior_mask + + @staticmethod + def from_seqlens(query_lens, seq_lens, block_size=None): + contexted = block_size is None + if contexted: + prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( + query_lens, seq_lens) + active_mask = None + else: + prior_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( + query_lens, seq_lens, block_size) + active_mask = BlockDiagonalCausalFromBottomRightMask._from_seqlens( + 
query_lens, query_lens) + return prior_mask, active_mask + + +def ref_softmax(x: torch.Tensor, + dim: int, + mixed_precision=False, + return_max_reduce=False): + max_value = torch.amax(x, dim=dim, keepdims=True) + exp = torch.exp(x - max_value) + if mixed_precision: + sum_value = torch.sum(exp.astype(torch.float32), + dim=dim, + keepdims=True).astype(x.dtype) + else: + sum_value = torch.sum(exp, dim=dim, keepdims=True) + if return_max_reduce: + return exp / sum_value, max_value, torch.reciprocal(sum_value) + return exp / sum_value + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, + return_max_reduce: Optional[bool] = False, +) -> torch.Tensor: + scaled_qk = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + masked_score = scaled_qk + attn_mask.float() + if return_max_reduce: + norm_score, cached_max, cached_sum_reciprocal = ref_softmax( + masked_score, dim=-1, return_max_reduce=True) + else: + norm_score = ref_softmax(masked_score, dim=-1) + out = torch.einsum("hqk,khd->qhd", norm_score, value) + if return_max_reduce: + return ( + out, + cached_max, + cached_sum_reciprocal, + norm_score, + masked_score, + scaled_qk, + ) + else: + return out + + +def ref_context_attention( + query, + key, + value, + query_lens, + seq_lens, + head_size, + num_kv_heads, + num_heads, + num_queries_per_kv, + return_max_reduce=False, +): + scale = float(1.0 / (head_size**0.5)) + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + + attn_mask, _ = BlockDiagonalCausalFromBottomRightMask.from_seqlens( + query_lens, seq_lens) + + # convert binary mask to -inf values + attn_mask = torch.logical_not(attn_mask) + attn_mask = attn_mask.float() * -30000 + + output, cached_max, cached_sum_reciprocal, lse, masked_score, scaled_qk = ( + ref_masked_attention( + query, + key, + value, + scale, + attn_mask, + return_max_reduce=return_max_reduce, + )) + + output = output.unsqueeze(1) + if return_max_reduce: + return ( + output, + cached_max, + cached_sum_reciprocal, + lse, + masked_score, + scaled_qk, + ) + else: + return output + + +@pytest.mark.parametrize( + "num_heads,num_queries_per_kv,head_size,mixed_precision", + [ + (4, 2, 8, False), + (4, 2, 8, True), + (32, 8, 64, True), + ], +) +@torch.inference_mode() +def test_contexted_kv_attention( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + mixed_precision: bool, +) -> None: + import os + + import torch_xla.core.xla_model as xm + + from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc + + device = xm.xla_device() + + os.environ["NEURON_CC_FLAGS"] = ( + " --model-type=transformer -O1 " + " --internal-hlo2tensorizer-options='--verify-hlo' ") + + random.seed(0) + torch.manual_seed(0) + torch.set_printoptions(sci_mode=False) + + min_ctx_len = 2 + max_ctx_len = 64 + min_query_len = 2 + max_query_len = 64 + prefill_batch_size = 2 + decode_batch_size = 6 + batch_size = prefill_batch_size + decode_batch_size + block_size = 32 + max_model_len = (max_query_len + max_ctx_len) * 4 + + max_block_per_request = max_model_len // block_size + dtype = torch.float32 + cache_size = (batch_size * max_block_per_request) + 2 + ctx_lens = [ + random.randint(min_ctx_len, max_ctx_len) + for _ in range(prefill_batch_size) + ] + [ + random.randint(min_ctx_len, max_ctx_len) + for _ in 
range(decode_batch_size) + ] + query_lens = [ + random.randint(min_query_len, max_query_len) + for _ in range(prefill_batch_size) + ] + [1 for _ in range(decode_batch_size)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] + num_kv_heads = num_heads // num_queries_per_kv + + num_tokens = sum(query_lens) + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + query.uniform_(-1, 1) + torch.empty(num_tokens, num_heads, head_size, dtype=dtype) + + kv = torch.empty(sum(seq_lens), 2, num_kv_heads, head_size, dtype=dtype) + kv.uniform_(-1, 1) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + v_cache = torch.zeros(cache_size, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) + values = values[torch.randperm(cache_size)] + block_table = values[:batch_size * max_block_per_request].view( + batch_size, max_block_per_request) + torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + dtype=torch.long), + dim=0) + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long), + dim=0) + for i in range(batch_size): + for j in range(query_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_kv_heads, + head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + + ( + output_ref, + cached_max, + cached_sum_reciprocal, + lse, + masked_score, + scaled_qk, + ) = ref_context_attention( + query, + key, + value, + query_lens, + seq_lens, + head_size, + num_kv_heads, + num_heads, + num_queries_per_kv, + return_max_reduce=True, + ) + + # build neuron program + return_debug_tensors = False + B_P_SIZE = 128 + LARGE_TILE_SZ = 2048 + max_num_queries = ( + (sum(query_lens) + block_size - 1) // block_size) * block_size + + def get_active_block_tables(block_tables, query_lens, seq_lens, block_size, + num_blocks): + context_lens = seq_lens - query_lens + blocks_per_seq = (context_lens + block_size - 1) // block_size + num_seqs = len(seq_lens) + active_blocks: list[int] = [] + for seq_id in range(num_seqs): + active_blocks = ( + active_blocks + + block_tables[seq_id, :blocks_per_seq[seq_id]].tolist()) + return F.pad( + torch.tensor(active_blocks), + (0, num_blocks - len(active_blocks)), + "constant", + 0, + ) + + def shift_bit_length(x): + return 1 << (x - 1).bit_length() + + # calculate input shapes + max_num_queries_shifted = shift_bit_length(max_num_queries) + max_num_queries_factor = B_P_SIZE // max_num_queries_shifted + max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor + assert (max_num_queries_padded == B_P_SIZE + ), "invalid 
{max_num_queries_padded=}" + head_size_padded = B_P_SIZE + context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens) + num_active_blocks_shifted = shift_bit_length( + ((context_lens + block_size - 1) // block_size).sum().item()) + num_active_blocks_factor = (LARGE_TILE_SZ // block_size // + num_active_blocks_shifted) + num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor + assert (num_active_blocks * + block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}" + context_kv_len = num_active_blocks * block_size + assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}" + + # pad QKV tensors + pad_dims = ( + 0, + head_size_padded - query.shape[2], + 0, + 0, + 0, + max_num_queries_padded - query.shape[0], + ) + query = F.pad(query, pad_dims, "constant", 0) + k = F.pad(k, pad_dims, "constant", 0) + v = F.pad(v, pad_dims, "constant", 0) + k_cache = F.pad(k_cache, (0, head_size_padded - head_size), "constant", 0) + v_cache = F.pad(v_cache, (0, head_size_padded - head_size), "constant", 0) + + # permute QKV tensors + # query: (1, n_heads, d, seq_q) + # key: (1, n_kv_heads, d, seq_k) + # value: (1, n_kv_heads, seq_v, d) + query = query.unsqueeze(0).permute(0, 2, 3, 1).contiguous() + k = k.unsqueeze(0).permute(0, 2, 3, 1).contiguous() + v = v.unsqueeze(0).permute(0, 2, 1, 3).contiguous() + + # transform block table + active_block_table = get_active_block_tables( + block_table, + torch.tensor(query_lens), + torch.tensor(seq_lens), + block_size, + num_active_blocks, + ) + + # Build attention masks + prior_mask, active_mask = ( + BlockDiagonalCausalFromBottomRightMask.from_seqlens( + query_lens, seq_lens, block_size=block_size)) + attn_mask = torch.concat( + [ + F.pad( + prior_mask, + ( + 0, + context_kv_len - prior_mask.shape[1], + 0, + B_P_SIZE - prior_mask.shape[0], + ), + "constant", + 0, + ).bool(), + F.pad( + active_mask, + ( + 0, + B_P_SIZE - active_mask.shape[1], + 0, + B_P_SIZE - active_mask.shape[0], + ), + "constant", + 0, + ).bool(), + ], + dim=1, + ) + + input_args = ( + query.to(device=device), + k.to(device=device), + v.to(device=device), + k_cache.to(device=device), + v_cache.to(device=device), + active_block_table.to(torch.int32).to(device=device), + attn_mask.to(device=device), + ) + input_kwargs = dict( + n_kv_head=num_kv_heads, + head_size=head_size, + mixed_precision=mixed_precision, + ) + + if return_debug_tensors: + output_nki, *debug_tensors = flash_attn_varlen_nkifunc( + *input_args, **input_kwargs) + else: + output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs) + debug_tensors = [] + + output_nki = torch.tensor(output_nki).cpu() + debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors] + + num_actual_tokens = sum(query_lens) + print(f"{num_actual_tokens=}") + # - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d) + output_nki = output_nki.permute( + 0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :] + output_ref_padded = F.pad( + output_ref, + (0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]), + "constant", + 0, + ) + output_ref = output_ref_padded.transpose(0, 1)[0, :num_actual_tokens, :, :] + + torch.testing.assert_close(output_nki, output_ref, atol=1e-2, rtol=0) diff --git a/vllm/attention/ops/nki_flash_attn.py b/vllm/attention/ops/nki_flash_attn.py new file mode 100644 index 0000000000000..b9765b0f0283d --- /dev/null +++ b/vllm/attention/ops/nki_flash_attn.py @@ -0,0 +1,669 @@ +from dataclasses import dataclass + +import neuronxcc.nki.isa as nisa +import 
neuronxcc.nki.language as nl +import numpy as np +from neuronxcc import nki +from neuronxcc.nki.language import par_dim + + +@dataclass(frozen=True) +class FlashConfig: + """ + Config class for flash attention with default values + """ + + seq_tile_size: int = 2048 + should_transpose_v: bool = False + + __annotations__ = { + "seq_tile_size": int, + "should_transpose_v": bool, + } + + +@nki.jit +def transpose_p_local(p_local_transposed, + p_local, + LARGE_TILE_SZ, + forward_mask, + B_F_SIZE=512): + for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.sbuf, + dtype=p_local.dtype) + else: + p_local_t_tmp = nl.ndarray((par_dim(128), B_F_SIZE), + buffer=nl.psum, + dtype=np.float32) + + for j in nl.affine_range(B_F_SIZE // 128): + j_128_slice = nl.ds(j * 128, 128) + i_j_128_slice = nl.ds(i * B_F_SIZE + j * 128, 128) + + if nisa.get_nc_version() == nisa.nc_version.gen3: + p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose( + p_local[:, i_j_128_slice], mask=forward_mask) + else: + p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose( + p_local[:, i_j_128_slice], mask=forward_mask) + + p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy( + p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask) + + +@nki.jit +def _flash_attention_core( + q_local_tile, + k, + v, + q_h_per_k_h, + seqlen_q, + nheads, + o_buffer, + l_buffer, + m_buffer, + batch_id, + head_id, + gqa_head_idx, + q_tile_idx, + local_k_large_tile_idx, + kernel_dtype, + acc_type, + flash_config: FlashConfig, + use_causal_mask=False, + continuous_batching_mask=None, + initialize=False, + B_P_SIZE=128, + B_F_SIZE=512, + B_D_SIZE=128, + dropout_p=0.0, + dropout_p_tensor=None, + seed_tensor=None, + logit_bias_tile=None, + qk_res_buffer=None, +): + """ + The flash attention core function to calculate self attention between a tile + of q and a block of K and V. + The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF + already. The block size of K and V + is defined in the seq_tile_size of the flash_config. The results are stored + in the following three buffers + o_buffer: (B_P_SIZE, d) + l_buffer: (B_P_SIZE, 1) + m_buffer: (B_P_SIZE, 1) + """ + LARGE_TILE_SZ = flash_config.seq_tile_size + num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE + seqlen_k = k.shape[-1] + seqlen_q // B_P_SIZE + seqlen_k // B_F_SIZE + + # TODO : support logit_bias with continuous_batching_mask + assert not use_causal_mask, "causal mask is not supported." + assert (continuous_batching_mask + is not None), "continuous_batching_mask input is required." + if continuous_batching_mask is not None: + assert (logit_bias_tile is + None), "continuous_batching_mask does not support logit_bias!" 
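+    # Online-softmax (flash-attention) bookkeeping: across large K tiles this
+    # function maintains three running buffers per query row:
+    #   m_buffer - running max of the scaled QK scores
+    #   o_buffer - unnormalized output; rescaled by exp(m_prev - m_cur) before
+    #              accumulating the current tile's P @ V contribution
+    #   l_buffer - running log-sum-exp, l = m + log(exp(l_prev - m) + sum(P))
+    # The caller applies the final normalization o * exp(m - l) when writing
+    # the result back to HBM.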
+ + # mask are used to only apply computation to the lower half of the matrix, + # which reduce the arthimetic intensity by half + forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx * + LARGE_TILE_SZ if use_causal_mask else None) + + qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + buffer=nl.sbuf, + dtype=acc_type) + max_local = nl.ndarray((par_dim(B_P_SIZE), num_k_tile_per_large_tile), + dtype=acc_type) + for k_i in nl.affine_range(num_k_tile_per_large_tile): + k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE) + + qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE), + dtype=np.float32, + buffer=nl.psum) # (128, 512) + qk_psum[:, :] = nl.matmul(q_local_tile, + k[:, k_i_b_f_slice], + transpose_x=True, + mask=None) # (p(128), 512) + + qk_res_buf[:, k_i_b_f_slice] = nl.where( + continuous_batching_mask[:, k_i_b_f_slice], + qk_psum[:, nl.ds(0, B_F_SIZE)], + -9984.0, + dtype=acc_type, + ) + + # Calculate max of the current tile + max_local[:, k_i] = nisa.tensor_reduce( + np.max, + qk_res_buf[:, k_i_b_f_slice], + axis=(1, ), + dtype=acc_type, + negate=False, + mask=forward_mask, + ) + + if qk_res_buffer is not None: + qk_res_buffer[:, :] = nl.copy(qk_res_buf[:, :]) + + max_ = nisa.tensor_reduce( + np.max, + max_local[:, :], + axis=(1, ), + dtype=acc_type, + negate=False, + mask=forward_mask, + ) + + o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE), + dtype=o_buffer.dtype) + + if initialize: + m_buffer[:, 0] = nl.copy(max_) + m_current = max_ + else: + m_previous = nl.copy(m_buffer[:, 0]) + m_buffer[:, 0] = nl.maximum(m_previous, max_, + mask=forward_mask) # (128,1) + + m_current = m_buffer[:, 0] + # Compute scaling factor + alpha = nisa.activation( + np.exp, + m_previous, + bias=-1 * m_current, + scale=1.0, + mask=forward_mask, + ) + o_previous_scaled[...] = nl.multiply(o_buffer[:, :], + alpha, + mask=forward_mask) + + p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + REDUCTION_TILE = min(2048, LARGE_TILE_SZ // 2) + + p_partial_sum = nl.ndarray( + (par_dim(B_P_SIZE), LARGE_TILE_SZ // REDUCTION_TILE), dtype=acc_type) + + for k_r_i in nl.affine_range(LARGE_TILE_SZ // REDUCTION_TILE): + k_r_i_reduce_slice = nl.ds(k_r_i * REDUCTION_TILE, REDUCTION_TILE) + + # compute exp(qk - max) + # Compute partial row - tile sum of exp(qk - max)) + # FIXME : Use activation accumulate to accumulate over k_r_i loop ? 
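+        # nisa.activation_reduce fuses both steps: it computes exp(qk - max)
+        # elementwise into p_local and, in the same pass, reduces each row of
+        # the tile with nl.add into p_partial_sum[:, k_r_i]; the softmax
+        # denominator is then finished by the nl.sum over the partials below.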
+ p_local[:, k_r_i_reduce_slice] = nisa.activation_reduce( + np.exp, + qk_res_buf[:, k_r_i_reduce_slice], + bias=-1 * m_current, + scale=1.0, + reduce_op=nl.add, + reduce_res=p_partial_sum[:, k_r_i], + dtype=kernel_dtype, + mask=forward_mask, + ) + + ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask) + + p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + transpose_p_local( + p_local_transposed=p_local_transposed, + p_local=p_local, + LARGE_TILE_SZ=LARGE_TILE_SZ, + forward_mask=forward_mask, + B_F_SIZE=B_F_SIZE, + ) + + pv_psum = nl.zeros((par_dim(B_P_SIZE), B_D_SIZE), + dtype=np.float32, + buffer=nl.psum) + for k_i in nl.affine_range(LARGE_TILE_SZ // B_P_SIZE): + pv_psum[:, :] += nl.matmul( + p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)], + v[k_i, :, :], + transpose_x=True, + mask=forward_mask, + ) # (128, 128) (p(Br), d) + + if initialize: + o_buffer[:, :] = nl.copy(pv_psum[:, :]) + l_buffer[:, 0] = nl.add(nl.log(ps), max_) + else: + o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask) + + l_prev = l_buffer[:, 0] + l_exp = nl.add( + nl.exp( + nl.subtract(l_prev, m_current, mask=forward_mask), + mask=forward_mask, + ), + ps, + mask=forward_mask, + ) + l_buffer[:, 0] = nl.add(m_current, + nl.log(l_exp, mask=forward_mask), + mask=forward_mask) + + +@nki.jit +def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config): + LARGE_TILE_SZ = config.seq_tile_size + B_P_SIZE = 128 + + if not config.should_transpose_v: + cur_v_tile[v_i, :, :] = nl.load( + v_hbm_tile[nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE), :], + dtype=cur_v_tile.dtype, + ) + return + + if nisa.get_nc_version() == nisa.nc_version.gen3: + cur_v_tile_transposed = nisa.dma_transpose( + v_hbm_tile[:, + nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)]) + cur_v_tile[v_i, :, :] = nisa.tensor_copy(cur_v_tile_transposed, + dtype=cur_v_tile.dtype) + return + + cur_v_tile[v_i, :, :] = nl.load_transpose2d( + v_hbm_tile[:, nl.ds(j * LARGE_TILE_SZ + B_P_SIZE * v_i, B_P_SIZE)], + dtype=cur_v_tile.dtype, + ) + + +@nki.jit +def flash_paged_attention( + query, + key, + value, + key_cache, + value_cache, + block_tables, + mask, + softmax_scale=None, + mixed_precision=True, + config=None, + return_debug_tensors=False, +): + """ + Flash PagedAttention Forward Kernel. + - PagedAttention Paper: https://arxiv.org/abs/2309.06180 + - Chunked Prefill Paper: https://arxiv.org/abs/2403.02310 + + IO tensor layouts: + - query: shape (1, n_heads, d, seq_q) + - key: shape (1, n_kv_heads, d, seq_k) + - value: shape (1, n_kv_heads, seq_v, d) + - key_cache: (num_blocks, block_size, n_kv_heads, d) + - value_cache: (num_blocks, block_size, n_kv_heads, d) + - block_tables: (num_active_blocks, ) + - mask: (seq_q, num_active_blocks * block_size) + - o: shape (1, n_heads, seq_q, d) + - l_m: shape (1, n_heads, seq_q, 2) + + - This kernel requires seq_k == seq_v + - We use continuous batching by default, so the batch dimension is + always 1, and different requests are concatenated along sequence + dimension. + - We use paged cache blocks (key_cache, value_cache) to store KV cache. + + IO tensor dtypes: + - This kernel assumes all IO tensors have the same dtype except for + block_tables (int32) and mask (int32) + - If mixed_percision is True, then all Tensor Engine operation will be + performed in bfloat16 and accumulation will be performed in float32. + Otherwise the intermediates will be in the same type as the inputs. 
+ + Compile-time Constants: + - softmax_scale: scaling for softmax, is None, default is `1.0/(d**0.5)` + - mixed_precision: flag to set non-matmul ops in fp32 precision, default + is set to `true`, if false, we use same precision as input types + - config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig` + with Performance config parameters for flash attention with default + values + seq_tile_size: `default=2048`, size of the kv tile size for attention + computation reduction + + GQA support Notes: + the spmd kernel for launching kernel should be on kv_heads instead of + nheads + + Example usage: + MHA: q: [b, h, d, s], k: [b, h, d, s], v: [b, h, s, d] + usage: `flash_fwd[b, h](q, k, v, ...)` + GQA: q: [b, h, d, s], k: [b, kv_h, d, s], v: [b, kv_h, s, d] + usage: `flash_fwd[b, kv_h](q, k, v, ...)` + """ + config = config or FlashConfig() + B_F_SIZE = 512 + B_P_SIZE = 128 + b, h, d, seqlen_q = query.shape + B_D_SIZE = d + LARGE_TILE_SZ = config.seq_tile_size + n_tile_q = seqlen_q // B_P_SIZE # since q will be loaded on tensor engine + num_blocks, block_size, k_h, _ = key_cache.shape + q_h_per_k_h = h // k_h + assert tuple(key_cache.shape) == ( + num_blocks, + block_size, + k_h, + d, + ), "Input shape mismatch!" + assert tuple(value_cache.shape) == ( + num_blocks, + block_size, + k_h, + d, + ), "Input shape mismatch!" + assert b == 1, f"invalid batch size {b=}" + assert d <= 128, f" we do not support head_dim > 128, got head dim {d}" + kernel_dtype = nl.bfloat16 if mixed_precision else query.dtype + acc_type = np.dtype(np.float32) if mixed_precision else kernel_dtype + + o = nl.ndarray((b, h, seqlen_q, d), + dtype=query.dtype, + buffer=nl.shared_hbm) + hbm_l_buffer, hbm_m_buffer, hbm_qk_res, qk_res_buffer = ( + None, + None, + None, + None, + ) + if return_debug_tensors: + hbm_l_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_m_buffer = nl.ndarray((b, h, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + hbm_qk_res = nl.ndarray((b, h, B_P_SIZE, seqlen_q), + dtype=acc_type, + buffer=nl.shared_hbm) + qk_res_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), seqlen_q), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + + assert ( + nl.program_ndim() == 2 + ), f"Expect spmd grid with 2 dimensions, got {nl.program_ndim()} instead!" 
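+    # The kernel is launched over a 2D SPMD grid of (batch, kv_head), e.g.
+    # flash_paged_attention[1, n_kv_heads](...); each program instance handles
+    # one KV head and loops over its q_h_per_k_h query heads (GQA/MQA).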
+ batch_id = nl.program_id(axis=0) + head_id = nl.program_id(axis=1) + + softmax_scale = softmax_scale or (1.0 / (d**0.5)) + + (num_active_blocks, ) = block_tables.shape + context_kv_len = num_active_blocks * block_size + assert (config.seq_tile_size >= 512 + ), f" seq tile_size {config.seq_tile_size} cannot be less than 512" + assert (context_kv_len % LARGE_TILE_SZ == 0 + ), f"Need {context_kv_len=} to be divisible by {LARGE_TILE_SZ=}" + assert ( + LARGE_TILE_SZ % B_P_SIZE == 0 + ), f"Need LARGE_TILE_SZ ({LARGE_TILE_SZ}) to be divisible by {B_P_SIZE=}" + assert (B_P_SIZE % block_size == 0 + ), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}" + num_large_k_tile = context_kv_len // LARGE_TILE_SZ + num_blocks_per_large_tile = LARGE_TILE_SZ // block_size + assert (num_blocks_per_large_tile <= B_P_SIZE + ), f"The number of blocks in each large tile " \ + f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}" + + block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile), + 0, + dtype=np.int32, + buffer=nl.sbuf) + for j in nl.affine_range(num_large_k_tile): + i_p = nl.arange(num_blocks_per_large_tile)[:, None] + block_tables_sbuf[i_p, j] = nl.load( + block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32) + + # Global Flash Attention accumulators + o_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), d), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + l_buffer = nl.zeros( + (par_dim(B_P_SIZE), n_tile_q, q_h_per_k_h), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + m_buffer = nl.zeros( + (n_tile_q, q_h_per_k_h, par_dim(B_P_SIZE), 1), + dtype=acc_type, + buffer=nl.sbuf, + lazy_initialization=True, + ) + + for j in nl.sequential_range(0, num_large_k_tile): + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + cur_v_tile = nl.ndarray( + (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + dtype=kernel_dtype, + ) + + for k_i in nl.affine_range(num_blocks_per_large_tile): + loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :, + head_id, :]) + cur_k_tile[:, nl.ds(k_i * + block_size, block_size)] = nl.transpose(loaded) + + load_tile_size = B_P_SIZE + num_blocks_per_partition = load_tile_size // block_size + for partition_idx in nl.affine_range(LARGE_TILE_SZ // load_tile_size): + for block_in_partition in nl.affine_range( + num_blocks_per_partition): + v_i = (partition_idx * num_blocks_per_partition + + block_in_partition) + loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :, + head_id, :]) + cur_v_tile[partition_idx, + nl.ds(block_in_partition * + block_size, block_size), :, ] = loaded_v + + cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ), + dtype=mask.dtype) + for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE): + cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load( + mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)]) + + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load( + q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], + dtype=kernel_dtype, + ) # load (d, 128) tile in SBUF + q_tile[:, :] = q_sbuf_tile * softmax_scale + + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + q_h_per_k_h=q_h_per_k_h, + seqlen_q=seqlen_q, + nheads=h, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[:, i, i_q_h], + 
m_buffer=m_buffer[i, i_q_h], + batch_id=batch_id, + head_id=head_id, + gqa_head_idx=i_q_h, + q_tile_idx=i, + local_k_large_tile_idx=j, + kernel_dtype=kernel_dtype, + acc_type=acc_type, + flash_config=config, + use_causal_mask=False, + continuous_batching_mask=cur_mask, + initialize=j == 0, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + dropout_p=0.0, + dropout_p_tensor=None, + seed_tensor=None, + logit_bias_tile=None, + ) + + # compute attention between input query, key and value + if key is not None and value is not None: + B_F_SIZE = seqlen_q + LARGE_TILE_SZ = seqlen_q + active_config = FlashConfig( + seq_tile_size=LARGE_TILE_SZ, + should_transpose_v=config.should_transpose_v, + ) + + cur_k_tile = nl.ndarray((par_dim(B_D_SIZE), LARGE_TILE_SZ), + dtype=kernel_dtype) + cur_v_tile = nl.ndarray( + (LARGE_TILE_SZ // B_P_SIZE, par_dim(B_P_SIZE), B_D_SIZE), + dtype=kernel_dtype, + ) + + cur_k_tile[:, :] = nl.load(key[batch_id, head_id, :, :]) + + load_tile_size = B_P_SIZE + v_hbm_tile = value[batch_id, head_id] + for v_i in nl.affine_range(LARGE_TILE_SZ // load_tile_size): + load_v_tile( + v_hbm_tile=v_hbm_tile, + cur_v_tile=cur_v_tile, + j=0, + v_i=v_i, + config=active_config, + ) + + cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype) + cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)]) + + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype) + q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h] + q_sbuf_tile = nl.load( + q_hbm_tile[:, nl.ds(i * B_P_SIZE, B_P_SIZE)], + dtype=kernel_dtype, + ) # load (d, 128) tile in SBUF + q_tile[:, :] = q_sbuf_tile * softmax_scale + _flash_attention_core( + q_local_tile=q_tile, + k=cur_k_tile, + v=cur_v_tile, + q_h_per_k_h=q_h_per_k_h, + seqlen_q=seqlen_q, + nheads=h, + o_buffer=o_buffer[i, i_q_h], + l_buffer=l_buffer[:, i, i_q_h], + m_buffer=m_buffer[i, i_q_h], + batch_id=batch_id, + head_id=head_id, + gqa_head_idx=i_q_h, + q_tile_idx=i, + local_k_large_tile_idx=0, + kernel_dtype=kernel_dtype, + acc_type=acc_type, + flash_config=active_config, + use_causal_mask=False, + continuous_batching_mask=cur_mask, + initialize=False, + B_P_SIZE=B_P_SIZE, + B_F_SIZE=B_F_SIZE, + B_D_SIZE=B_D_SIZE, + dropout_p=0.0, + dropout_p_tensor=None, + seed_tensor=None, + logit_bias_tile=None, + qk_res_buffer=qk_res_buffer[i, i_q_h] + if qk_res_buffer is not None else None, + ) + + # -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- # + for i_q_h in nl.affine_range(q_h_per_k_h): + for i in nl.affine_range(n_tile_q): + out = nl.multiply( + o_buffer[i, i_q_h, :, :], + nl.exp(m_buffer[i, i_q_h, :, :] - l_buffer[:, i, i_q_h]), + dtype=kernel_dtype, + ) + + nl.store( + o[batch_id, head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), :, ], + out, + ) + # maximum and summation statistics + if return_debug_tensors: + nl.store( + hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), ], + m_buffer[i, i_q_h, :, :], + ) + nl.store( + hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h, + nl.ds(i * B_P_SIZE, B_P_SIZE), ], + l_buffer[:, i, i_q_h], + ) + nl.store( + hbm_qk_res[batch_id, head_id * q_h_per_k_h + i_q_h, :, :], + qk_res_buffer[batch_id, i_q_h, :, :], + ) + + if return_debug_tensors: + return o, hbm_m_buffer, hbm_l_buffer, hbm_qk_res + return o + + +def flash_attn_varlen_nkifunc( + query, + key, + value, + key_cache, + value_cache, + block_table, + attn_mask, + n_kv_head=None, + 
head_size=None, + B_P_SIZE=128, + LARGE_TILE_SZ=2048, + return_debug_tensors=False, + mixed_precision=True, +): + config = FlashConfig( + seq_tile_size=LARGE_TILE_SZ, + should_transpose_v=False, + ) + kwargs = dict( + query=query, + key=key, + value=value, + key_cache=key_cache, + value_cache=value_cache, + block_tables=block_table, + mask=attn_mask, + softmax_scale=1.0 / (head_size**0.5), + config=config, + mixed_precision=mixed_precision, + return_debug_tensors=return_debug_tensors, + ) + _, n_kv_head, _, _ = key.shape + + if return_debug_tensors: + o, *debug_tensors = flash_paged_attention[1, n_kv_head](**kwargs) + return o, *debug_tensors + else: + o = flash_paged_attention[1, n_kv_head](**kwargs) + return o
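Below is a minimal, shape-only sketch of how the new entry point can be driven, mirroring the tensor layouts documented in flash_paged_attention and exercised by tests/neuron/test_prefix_prefill.py. It assumes a Neuron device is reachable through torch_xla, uses an all-true attention mask purely as a placeholder (the test constructs the real block-diagonal causal mask), and picks sizes that satisfy the kernel's asserts: seq_q is a multiple of 128, head_size <= 128, and num_active_blocks * block_size equals the 2048-token KV tile.

import torch
import torch_xla.core.xla_model as xm

from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

device = xm.xla_device()

n_heads, n_kv_heads, head_size = 4, 2, 128       # head_size must be <= 128
block_size, num_active_blocks = 32, 64           # 64 * 32 == LARGE_TILE_SZ == 2048
seq_q = 128                                      # multiple of B_P_SIZE (128)
context_kv_len = num_active_blocks * block_size

# Layouts from the kernel docstring:
query = torch.rand(1, n_heads, head_size, seq_q)      # (1, n_heads, d, seq_q)
key = torch.rand(1, n_kv_heads, head_size, seq_q)     # active keys   (1, n_kv_heads, d, seq_k)
value = torch.rand(1, n_kv_heads, seq_q, head_size)   # active values (1, n_kv_heads, seq_v, d)
key_cache = torch.rand(num_active_blocks + 1, block_size, n_kv_heads, head_size)
value_cache = torch.rand(num_active_blocks + 1, block_size, n_kv_heads, head_size)
block_table = torch.arange(num_active_blocks, dtype=torch.int32)

# Placeholder mask: prior (paged context) columns followed by the active
# (query x query) columns; the test builds the real block-diagonal causal mask.
attn_mask = torch.ones(seq_q, context_kv_len + seq_q, dtype=torch.bool)

out = flash_attn_varlen_nkifunc(
    query.to(device),
    key.to(device),
    value.to(device),
    key_cache.to(device),
    value_cache.to(device),
    block_table.to(device),
    attn_mask.to(device),
    n_kv_head=n_kv_heads,
    head_size=head_size,
    mixed_precision=True,
)
# out has shape (1, n_heads, seq_q, head_size)

The result can then be compared against a reference implementation, as test_contexted_kv_attention does after slicing off the padded query rows and head dimensions.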