Fix missed precommit format files (NKI FA) #12497

Closed
37 changes: 25 additions & 12 deletions vllm/attention/ops/nki_flash_attn.py
@@ -106,11 +106,12 @@ def _flash_attention_core(
assert (continuous_batching_mask
is not None), "continuous_batching_mask input is required."
if continuous_batching_mask is not None:
-assert (logit_bias_tile is
-        None), "continuous_batching_mask does not support logit_bias!"
+assert (
+    logit_bias_tile
+    is None), "continuous_batching_mask does not support logit_bias!"

# mask are used to only apply computation to the lower half of the matrix,
-# which reduce the arthimetic intensity by half
+# which reduce the arithmetic intensity by half
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
LARGE_TILE_SZ if use_causal_mask else None)

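The comment in this hunk explains that the causal mask limits computation to the lower triangle of the score matrix, roughly halving arithmetic intensity. As a minimal standalone sketch (not from the PR; the tile sizes below are invented, not the kernel's), the same tile-level predicate used for forward_mask determines how many Q/K tile pairs need to be touched at all:

    # Python sketch: count Q/K tile pairs that survive the causal predicate.
    B_P_SIZE = 128        # illustrative query-tile size
    LARGE_TILE_SZ = 2048  # illustrative key-tile size
    seq_len = 8192

    num_q_tiles = seq_len // B_P_SIZE
    num_k_tiles = seq_len // LARGE_TILE_SZ

    computed = 0
    for q_tile_idx in range(num_q_tiles):
        for k_tile_idx in range(num_k_tiles):
            # Same condition as forward_mask: compute only if this Q tile can
            # attend to at least one position of this K tile.
            if q_tile_idx * B_P_SIZE >= k_tile_idx * LARGE_TILE_SZ:
                computed += 1

    print(f"{computed}/{num_q_tiles * num_k_tiles} tile pairs computed")
    # -> 160/256 here: fully upper-triangular tiles are skipped entirely,
    # while diagonal tiles are still computed and masked element-wise.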
@@ -468,9 +469,11 @@ def flash_paged_attention(
block_in_partition)
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
head_id, :])
-cur_v_tile[partition_idx,
-           nl.ds(block_in_partition *
-                 block_size, block_size), :, ] = loaded_v
+cur_v_tile[
+    partition_idx,
+    nl.ds(block_in_partition * block_size, block_size),
+    :,
+] = loaded_v

cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
dtype=mask.dtype)
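For readers unfamiliar with the NKI indexing used here: nl.ds(start, size) is NKI's dynamic-slice helper, covering size elements starting at start, so the re-wrapped assignment above should be behaviorally identical to the original; only the line breaking changes. A rough NumPy analogy, written under that assumption with invented shapes (this is an illustration, not NKI code):

    import numpy as np

    # Invented shapes purely for illustration.
    block_size, blocks_per_partition, head_dim = 32, 4, 8
    cur_v_tile = np.zeros((2, block_size * blocks_per_partition, head_dim))
    loaded_v = np.ones((block_size, head_dim))

    partition_idx, block_in_partition = 0, 2
    start = block_in_partition * block_size
    # nl.ds(start, block_size) plays the role of start:start + block_size here.
    cur_v_tile[partition_idx, start:start + block_size, :] = loaded_v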
@@ -601,20 +604,30 @@ def flash_paged_attention(
)

nl.store(
-o[batch_id, head_id * q_h_per_k_h + i_q_h,
-  nl.ds(i * B_P_SIZE, B_P_SIZE), :, ],
+o[
+    batch_id,
+    head_id * q_h_per_k_h + i_q_h,
+    nl.ds(i * B_P_SIZE, B_P_SIZE),
+    :,
+],
out,
)
# maximum and summation statistics
if return_debug_tensors:
nl.store(
-hbm_m_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-             nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+hbm_m_buffer[
+    batch_id,
+    head_id * q_h_per_k_h + i_q_h,
+    nl.ds(i * B_P_SIZE, B_P_SIZE),
+],
m_buffer[i, i_q_h, :, :],
)
nl.store(
-hbm_l_buffer[batch_id, head_id * q_h_per_k_h + i_q_h,
-             nl.ds(i * B_P_SIZE, B_P_SIZE), ],
+hbm_l_buffer[
+    batch_id,
+    head_id * q_h_per_k_h + i_q_h,
+    nl.ds(i * B_P_SIZE, B_P_SIZE),
+],
l_buffer[:, i, i_q_h],
)
nl.store(
8 changes: 4 additions & 4 deletions vllm/spec_decode/spec_decode_worker.py
@@ -870,10 +870,10 @@ def _verify_tokens(
accepted_index = accepted_token_ids + 1 # Convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b
# Drop non-terminal prefill chunks hidden states.
-hidden_states = hidden_states[
-    accepted_index != VLLM_INVALID_TOKEN_ID]
-accepted_index = accepted_index[
-    accepted_index != VLLM_INVALID_TOKEN_ID]
+hidden_states = hidden_states[accepted_index !=
+                              VLLM_INVALID_TOKEN_ID]
+accepted_index = accepted_index[accepted_index !=
+                                VLLM_INVALID_TOKEN_ID]
assert len(accepted_index) == hidden_states.shape[0] == len(
terminal_metadata)
index = accepted_index[:, None, None].expand(-1, 1,
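The spec_decode_worker.py hunk only re-wraps two boolean-mask filters; their behavior is unchanged. A small PyTorch sketch of the same pattern (the tensors and the sentinel value are invented for illustration), filtering hidden states and indices with one shared mask so the rows stay aligned:

    import torch

    VLLM_INVALID_TOKEN_ID = -1  # assumed sentinel, for illustration only

    hidden_states = torch.randn(4, 8)  # one row per sequence
    accepted_index = torch.tensor([2, VLLM_INVALID_TOKEN_ID, 0, 1])

    keep = accepted_index != VLLM_INVALID_TOKEN_ID  # shared boolean mask
    hidden_states = hidden_states[keep]             # drop masked rows first
    accepted_index = accepted_index[keep]           # then filter the index itself

    assert len(accepted_index) == hidden_states.shape[0]  # rows stay aligned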