wip flash-infer
Signed-off-by: Lucas Wilkinson <[email protected]>
LucasWilkinson committed Feb 20, 2025
1 parent a6c0438 commit 1df44c3
Showing 2 changed files with 100 additions and 4 deletions.
85 changes: 85 additions & 0 deletions examples/deepseek-chat.py
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
)
sampling_params = SamplingParams(temperature=0.5)


def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)


print("=" * 80)

# In this script, we demonstrate how to pass input to the chat method:

conversation = [
    {
        "role": "system",
        "content": "You are a helpful assistant"
    },
    {
        "role": "user",
        "content": "Hello"
    },
    {
        "role": "assistant",
        "content": "Hello! How can I assist you today?"
    },
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
outputs = llm.chat(conversation,
                   sampling_params=sampling_params,
                   use_tqdm=False)
print_outputs(outputs)

# You can run batch inference with the llm.chat API.
conversation = [
    {
        "role": "system",
        "content": "You are a helpful assistant"
    },
    {
        "role": "user",
        "content": "Hello"
    },
    {
        "role": "assistant",
        "content": "Hello! How can I assist you today?"
    },
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
conversations = [conversation for _ in range(10)]

# We turn on the tqdm progress bar to verify that it's indeed running batch
# inference.
outputs = llm.chat(messages=conversations,
                   sampling_params=sampling_params,
                   use_tqdm=True)
print_outputs(outputs)

# A chat template can be optionally supplied.
# If not, the model will use its default chat template.

# with open('template_falcon_180b.jinja', "r") as f:
#     chat_template = f.read()

# outputs = llm.chat(
#     conversations,
#     sampling_params=sampling_params,
#     use_tqdm=False,
#     chat_template=chat_template,
# )
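
For reference, a minimal sketch (not part of the commit) of batching distinct conversations instead of duplicating one; it reuses the llm, sampling_params, and print_outputs defined in the example above, and the question strings are made up for illustration.

# Hypothetical follow-on: one conversation per question, batched through llm.chat.
questions = [
    "Summarize the benefits of mixture-of-experts models.",
    "Explain paged KV caching in one paragraph.",
    "Write a haiku about GPUs.",
]
conversations = [[
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": question},
] for question in questions]

outputs = llm.chat(messages=conversations,
                   sampling_params=sampling_params,
                   use_tqdm=True)
print_outputs(outputs)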
19 changes: 15 additions & 4 deletions vllm/attention/backends/flashinfer_mla.py
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

+import copy
 from contextlib import contextmanager
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type
@@ -125,7 +126,7 @@ def graph_clone(self, batch_size: int):
         assert self._is_graph_capturing
         state = self.__class__(self.runner)
         state._workspace_buffer = self._graph_decode_workspace_buffer
-        state._decode_wrapper = self._graph_decode_wrapper
+        state._decode_wrapper = copy.copy(self._graph_decode_wrapper)
         return state

     def graph_capture_get_metadata_for_batch(
@@ -197,10 +198,12 @@ def begin_forward(self, model_input):
         # In case of multistep chunked-prefill, there might be prefill requests
         # scheduled while CUDA graph mode is enabled. We don't run graph in that
         # case.
+        print("begin_forward", model_input.input_tokens.shape[0])
         if use_cuda_graph and is_decode:
             batch_size = model_input.input_tokens.shape[0]
             state = (self.runner.graph_runners[model_input.virtual_engine]
                      [batch_size].attn_state)
+            print("choosing decode_wrapper", batch_size)
             model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
             model_input.attn_metadata.begin_forward()

@@ -421,9 +424,17 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             self.paged_kv_indptr.extend([self.paged_kv_indptr[-1]] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size)
-            query_start_loc_host = torch.functional.F.pad(
-                query_start_loc_host, (cuda_graph_pad_size + 1, ),
-                value=query_start_loc_host[-1].item())
+
+            print(cuda_graph_pad_size + 1 - query_start_loc_host.shape[0],
+                  cuda_graph_pad_size + 1, query_start_loc_host.shape[0])
+            if cuda_graph_pad_size + 1 > query_start_loc_host.shape[0]:
+                query_start_loc_host = torch.cat(
+                    (query_start_loc_host,
+                     torch.full((cuda_graph_pad_size + 1 -
+                                 query_start_loc_host.shape[0], ),
+                                fill_value=query_start_loc_host[-1].item(),
+                                dtype=torch.int32,
+                                device="cpu")))

         if len(self.paged_kv_indptr) > 0:
             # extend to the maximum number of blocks as returned by the
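
For context, a minimal standalone sketch (not part of the commit; the concrete numbers and the target_len name are hypothetical) of the padding pattern in the build() hunk above: the cumulative query_start_loc tensor is extended to a fixed length expected by the captured CUDA graph by repeating its last value, so the padded slots describe zero-length queries.

import torch

# Hypothetical inputs: offsets for 3 real queries, padded out to a fixed length.
query_start_loc_host = torch.tensor([0, 5, 9, 12], dtype=torch.int32)
target_len = 9

if target_len > query_start_loc_host.shape[0]:
    query_start_loc_host = torch.cat(
        (query_start_loc_host,
         torch.full((target_len - query_start_loc_host.shape[0], ),
                    fill_value=query_start_loc_host[-1].item(),
                    dtype=torch.int32,
                    device="cpu")))

print(query_start_loc_host)
# tensor([ 0,  5,  9, 12, 12, 12, 12, 12, 12], dtype=torch.int32)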
