flashinfer paged attention #2772

Closed · wants to merge 22 commits
3 changes: 1 addition & 2 deletions csrc/cache.h
@@ -18,8 +18,7 @@ void copy_blocks(
void reshape_and_cache(
torch::Tensor& key,
torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& kv_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype);

55 changes: 36 additions & 19 deletions csrc/cache_kernels.cu
@@ -155,15 +155,16 @@ template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
//cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
//cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
cache_t* __restrict__ kv_cache, // [num_blocks, 2, block_size, num_heads, head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride,
const int value_stride,
const int num_heads,
const int head_size,
const int block_size,
const int x) {
const int block_size
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
@@ -176,14 +177,16 @@ __global__ void reshape_and_cache_kernel(

const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
//ok
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;

//ok
const int head_idx = i / head_size;
const int head_offset = i % head_size;
const int x_idx = head_offset / x;
const int x_offset = head_offset % x;

//const int x_idx = head_offset / x;
//const int x_offset = head_offset % x;
/*
const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ head_idx * (head_size / x) * block_size * x
+ x_idx * block_size * x
@@ -193,18 +196,33 @@ __global__ void reshape_and_cache_kernel(
+ head_idx * head_size * block_size
+ head_offset * block_size
+ block_offset;
[num_blocks, 2, block_size, num_heads, head_size]
[num_blocks, num_heads, head_size, block_size]
*/
const int64_t tgt_key_idx = block_idx * 2 * block_size * num_heads * head_size
+ 0 * block_size * num_heads * head_size +
+ block_offset * num_heads * head_size
+ head_idx * head_size
+ head_offset;

const int64_t tgt_value_idx = block_idx * 2 * block_size * num_heads * head_size
+ 1 * block_size * num_heads * head_size +
+ block_offset * num_heads * head_size
+ head_idx * head_size
+ head_offset;

scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
kv_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
kv_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
assert(false);
#endif
} else {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
kv_cache[tgt_key_idx] = tgt_key;
kv_cache[tgt_value_idx] = tgt_value;
}
}
}
@@ -215,29 +233,28 @@ __global__ void reshape_and_cache_kernel(
vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), \
key_stride, \
value_stride, \
num_heads, \
head_size, \
block_size, \
x);
block_size);

void reshape_and_cache(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
//torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
//torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& kv_cache, // [num_blocks, 2, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype)
{
int num_tokens = key.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int block_size = kv_cache.size(2);
//int x = kv_cache.size(4);

int key_stride = key.stride(0);
int value_stride = value.stride(0);
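For orientation, here is a minimal host-side reference sketch (not part of the PR) of the write pattern the rewritten kernel implements: with the combined layout [num_blocks, 2, block_size, num_heads, head_size], index 0 on the second dimension holds keys and index 1 holds values, and a token's slot index is split into a block index plus an offset inside that block. The function name and the Python loop are illustrative only.

import torch

def reshape_and_cache_ref(key, value, kv_cache, slot_mapping, block_size):
    # key, value:   [num_tokens, num_heads, head_size]
    # kv_cache:     [num_blocks, 2, block_size, num_heads, head_size]
    # slot_mapping: [num_tokens]; negative entries mean the token is skipped.
    for token_idx, slot_idx in enumerate(slot_mapping.tolist()):
        if slot_idx < 0:
            continue
        block_idx = slot_idx // block_size
        block_offset = slot_idx % block_size
        kv_cache[block_idx, 0, block_offset] = key[token_idx]    # keys at index 0
        kv_cache[block_idx, 1, block_offset] = value[token_idx]  # values at index 1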
3 changes: 2 additions & 1 deletion requirements.txt
@@ -11,4 +11,5 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
aioprometheus[starlette]
pynvml == 11.5.0
triton >= 2.1.0
flashinfer @ https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.1/flashinfer-0.0.1+cu121-cp310-cp310-linux_x86_64.whl
triton >= 2.1.0
2 changes: 2 additions & 0 deletions tests/async_engine/api_server_async_engine.py
@@ -39,6 +39,8 @@ def stats() -> Response:
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

args.enforce_eager = True

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
vllm.entrypoints.api_server.engine = engine
6 changes: 4 additions & 2 deletions tests/test_regression.py
@@ -15,7 +15,8 @@ def test_duplicated_ignored_sequence_group():
max_tokens=256)
llm = LLM(model="facebook/opt-125m",
max_num_batched_tokens=4096,
tensor_parallel_size=1)
tensor_parallel_size=1,
enforce_eager=True)
prompts = ["This is a short prompt", "This is a very long prompt " * 1000]
outputs = llm.generate(prompts, sampling_params=sampling_params)

@@ -28,7 +29,8 @@ def test_max_tokens_none():
max_tokens=None)
llm = LLM(model="facebook/opt-125m",
max_num_batched_tokens=4096,
tensor_parallel_size=1)
tensor_parallel_size=1,
enforce_eager=True)
prompts = ["Just say hello!"]
outputs = llm.generate(prompts, sampling_params=sampling_params)

4 changes: 4 additions & 0 deletions vllm/model_executor/input_metadata.py
@@ -27,6 +27,8 @@ def __init__(
block_tables: Optional[torch.Tensor],
use_cuda_graph: bool,
kv_cache_dtype: str,
decode_wrapper = None,
prefill_wrapper = None
) -> None:
self.is_prompt = is_prompt
self.prompt_lens = prompt_lens
@@ -38,6 +40,8 @@ def __init__(
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph
self.kv_cache_dtype = kv_cache_dtype
self.prefill_wrapper = prefill_wrapper
self.decode_wrapper = decode_wrapper

# Set during the execution of the first attention op.
# FIXME(woosuk): This is a hack.
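The two new fields carry flashinfer wrapper objects. A rough sketch of how a decode wrapper might be created and threaded through the new keyword arguments is below; the workspace size and the "NHD" layout string are assumptions based on flashinfer's paged-KV decode API, not values taken from this PR.

import torch
import flashinfer

# Assumed setup; buffer size and layout are illustrative.
workspace_buffer = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")

# The wrapper would then be passed along when building the metadata object, e.g.
# InputMetadata(..., decode_wrapper=decode_wrapper, prefill_wrapper=None)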
88 changes: 45 additions & 43 deletions vllm/model_executor/layers/attention.py
@@ -17,7 +17,7 @@
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

import flashinfer

class PagedAttention(nn.Module):
"""MHA/MQA/GQA layer with PagedAttention.
@@ -63,8 +63,8 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
kv_cache: Optional[torch.Tensor],
#value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""PagedAttention forward pass.
@@ -81,49 +81,46 @@ def forward(
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""

batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
query = query.view(-1, self.num_heads, self.head_size).contiguous()
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
cache_ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)
if kv_cache is not None:
cache_ops.reshape_and_cache(key, value, kv_cache,
input_metadata.slot_mapping.flatten(),
"auto")

if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :, None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# old attn
if kv_cache is None:
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.

if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
@@ -159,6 +156,18 @@ def forward(
(is_hip()) else None,
)
output = out.view_as(query)
elif input_metadata.block_tables.numel() == 0:
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.

query = query.view(-1, self.num_kv_heads, self.head_size)
#output = input_metadata.prefill_wrapper.forward(
# query, kv_cache, causal=True)

output = flashinfer.single_prefill_with_kv_cache(query, key.contiguous(), value.contiguous(), causal=True)
#allow_fp16_qk_reduction=True)

else:
# prefix-enabled attention
output = torch.empty_like(query)
@@ -167,8 +176,8 @@ def forward(
key,
value,
output,
key_cache,
value_cache,
#key_cache,
#value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
@@ -178,18 +187,11 @@ def forward(
)

else:
# Decoding run.
output = _paged_attention(
output = input_metadata.decode_wrapper.forward(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
kv_cache,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)


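The new prefill branch above calls flashinfer's single-request prefill kernel directly. A standalone sketch of that call, with dummy tensors in the default "NHD" layout ([seq_len, num_heads, head_dim]), is shown below; the shapes and dtype are illustrative assumptions, not values from the PR.

import torch
import flashinfer

seq_len, num_heads, head_dim = 16, 8, 128
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

# Causal single-prefill attention, matching the call in the diff above.
out = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)
print(out.shape)  # torch.Size([16, 8, 128])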
7 changes: 4 additions & 3 deletions vllm/model_executor/models/mistral.py
@@ -88,6 +88,7 @@ def __init__(self,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,

linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
@@ -145,13 +146,13 @@ def forward(
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
input_metadata: InputMetadata
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
#k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, kv_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output

4 changes: 2 additions & 2 deletions vllm/model_executor/models/opt.py
@@ -101,8 +101,8 @@ def forward(
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
#key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, kv_cache,
input_metadata)
output, _ = self.out_proj(attn_output)
return output