Merge pull request #3 from gmlwns2000/hip12-offload-add-hip

update

gmlwns2000 authored Dec 26, 2024
2 parents 68a3150 + 5335cf2 commit 422415b
Showing 13 changed files with 327 additions and 61 deletions.
27 changes: 27 additions & 0 deletions python/sglang/srt/layers/attention/hip_attention/hip_config.py
@@ -43,17 +43,24 @@ def __post_init__(self, parsed_json: dict | None):
if parsed_json is not None:
if 'second_stage_k' in parsed_json:
self.second_stage_k = parsed_json['second_stage_k']
parsed_json.pop('second_stage_k')
if 'sliding_window_size' in parsed_json:
self.sliding_window_size = parsed_json['sliding_window_size']
parsed_json.pop('sliding_window_size')
if 'sink_token_size' in parsed_json:
self.sink_token_size = parsed_json['sink_token_size']
parsed_json.pop('sink_token_size')
if 'sa_extend_backend' in parsed_json:
self.sa_extend_backend = parsed_json['sa_extend_backend']
parsed_json.pop('sa_extend_backend')
if 'stages' in parsed_json:
self.stages = [
ScanStage(**stage)
for stage in parsed_json['stages']
]
parsed_json.pop('stages')
if parsed_json:
raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')


@dataclass
@@ -65,6 +72,8 @@ class HiPAttentionConfig:
force_dense: bool = False
prefill_dense_threshold: int = 8192
block_sparse_block_size_q: int = 64
metadata_cache_max_batch_size: int = 256
mask_refresh_interval: int = 4
layers: list[HiPAttentionPerLayerConfig] = field(default_factory=lambda: [
HiPAttentionPerLayerConfig(parsed_json={'second_stage_k': 4096, 'sliding_window_size': 8192, 'sink_token_size': 8192}),
HiPAttentionPerLayerConfig(),
@@ -77,18 +86,36 @@ def __post_init__(self, parsed_json: dict | None):
if parsed_json is not None:
if 'apply_v_dot' in parsed_json:
self.apply_v_dot = parsed_json['apply_v_dot']
parsed_json.pop('apply_v_dot')
if 'dense_layers' in parsed_json:
self.dense_layers = parsed_json['dense_layers']
parsed_json.pop('dense_layers')
if 'prefill_always_dense' in parsed_json:
self.prefill_always_dense = parsed_json['prefill_always_dense']
parsed_json.pop('prefill_always_dense')
if 'decode_always_dense' in parsed_json:
self.decode_always_dense = parsed_json['decode_always_dense']
parsed_json.pop('decode_always_dense')
if 'force_dense' in parsed_json:
self.force_dense = parsed_json['force_dense']
parsed_json.pop('force_dense')
if 'prefill_dense_threshold' in parsed_json:
self.prefill_dense_threshold = parsed_json['prefill_dense_threshold']
parsed_json.pop('prefill_dense_threshold')
if 'block_sparse_block_size_q' in parsed_json:
self.block_sparse_block_size_q = parsed_json['block_sparse_block_size_q']
parsed_json.pop('block_sparse_block_size_q')
if 'metadata_cache_max_batch_size' in parsed_json:
self.metadata_cache_max_batch_size = parsed_json['metadata_cache_max_batch_size']
parsed_json.pop('metadata_cache_max_batch_size')
if 'mask_refresh_interval' in parsed_json:
self.mask_refresh_interval = parsed_json['mask_refresh_interval']
parsed_json.pop('mask_refresh_interval')
if 'layers' in parsed_json:
self.layers = [
HiPAttentionPerLayerConfig(parsed_json=layer)
for layer in parsed_json['layers']
]
parsed_json.pop('layers')
if parsed_json:
raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')
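
Both the per-layer and the top-level config classes follow the same pop-and-validate pattern: each recognized key is popped off the parsed JSON, and anything left over raises a ValueError so that typos in a config file fail loudly. A minimal standalone sketch of that pattern (the field names below are hypothetical, not the full HiP schema):

from dataclasses import InitVar, dataclass

@dataclass
class ExampleConfig:
    parsed_json: InitVar[dict | None] = None
    second_stage_k: int = 2048
    sliding_window_size: int = 1024

    def __post_init__(self, parsed_json: dict | None):
        if parsed_json is not None:
            # Pop each known key; whatever remains is unrecognized.
            if 'second_stage_k' in parsed_json:
                self.second_stage_k = parsed_json.pop('second_stage_k')
            if 'sliding_window_size' in parsed_json:
                self.sliding_window_size = parsed_json.pop('sliding_window_size')
            if parsed_json:
                raise ValueError(f'Unknown keys in json: {parsed_json.keys()}')

cfg = ExampleConfig(parsed_json={'second_stage_k': 4096})
assert cfg.second_stage_k == 4096
assert cfg.sliding_window_size == 1024
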
@@ -1,2 +1,229 @@
class HiPCudaGraphRunner:
pass
from __future__ import annotations

import bisect
from typing import TYPE_CHECKING, Callable

import torch
import tqdm
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import graph_capture

from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner, patch_model, clamp_position

if TYPE_CHECKING:
from sglang.srt.model_executor.hip_model_runner import HiPModelRunner


class HiPCudaGraphRunner(CudaGraphRunner):

def __init__(self, model_runner: "HiPModelRunner"):
super().__init__(model_runner)

def can_run(self, forward_batch: ForwardBatch):
use_cached_mask = forward_batch.hip_use_cached_mask

if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
forward_batch.global_num_tokens
)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
(min_num_tokens == max_num_tokens and (max_num_tokens, use_cached_mask) in self.graphs)
if self.disable_padding
else max_num_tokens <= self.max_bs
)
else:
is_bs_supported = (
(forward_batch.batch_size, use_cached_mask) in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)

# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
# because the full_text_row_masked_out_mask tensor will always be ones
is_encoder_lens_supported = (
torch.all(forward_batch.encoder_lens > 0)
if self.is_encoder_decoder
else True
)
return is_bs_supported and is_encoder_lens_supported

def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
capture_bs = (
tqdm.tqdm(self.capture_bs)
if get_tensor_model_parallel_rank() == 0
else self.capture_bs
)
for bs in capture_bs:
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
bs,
self.model_runner.tp_group,
) as forward:
for use_cached_mask in [False, True]:
(
graph,
output_buffers,
) = self.capture_one_batch_size(bs, forward, use_cached_mask)
self.graphs[(bs, use_cached_mask)] = graph
self.output_buffers[(bs, use_cached_mask)] = output_buffers

def capture_one_batch_size(self, bs: int, forward: Callable, hip_use_cached_mask: bool = False):
graph = torch.cuda.CUDAGraph()
stream = self.stream

# Common inputs
input_ids = self.input_ids[:bs]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs]
if self.is_encoder_decoder:
encoder_lens = self.encoder_lens[:bs]
else:
encoder_lens = None

seq_lens_sum = seq_lens.sum().item()
mrope_positions = self.mrope_positions[:, :bs]

if self.enable_dp_attention:
global_num_tokens = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
else:
global_num_tokens = None
gathered_buffer = None

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
req_pool_indices,
seq_lens,
encoder_lens,
)

# Run and capture
def run_once():
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DECODE,
batch_size=bs,
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
hip_metadata_cache_pool=self.model_runner.hip_metadata_cache_pool,
hip_use_cached_mask=hip_use_cached_mask,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
global_num_tokens=global_num_tokens,
gathered_buffer=gathered_buffer,
)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits

for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()

run_once()

torch.cuda.synchronize()
self.model_runner.tp_group.barrier()

torch.cuda.synchronize()
self.model_runner.tp_group.barrier()

with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
out = run_once()

torch.cuda.synchronize()
self.model_runner.tp_group.barrier()

self.graph_memory_pool = graph.pool()
return graph, out

def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
raw_bs = forward_batch.batch_size

# Pad
if self.enable_dp_attention:
index = bisect.bisect_left(
self.capture_bs, max(forward_batch.global_num_tokens)
)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
self.seq_lens.fill_(1)
self.out_cache_loc.zero_()

# Common inputs
self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)

# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs,
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum + (bs - raw_bs),
self.encoder_lens,
)

# Replay
key = (bs, forward_batch.hip_use_cached_mask)
self.graphs[key].replay()
next_token_logits = self.output_buffers[key][:raw_bs]

# Extract logprobs
if forward_batch.return_logprob:
logits_metadata = LogitsMetadata(
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
next_token_logprobs = (
LogitsProcessor.compute_temp_top_p_normalized_logprobs(
next_token_logits, logits_metadata
)
)
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
next_token_logprobs=next_token_logprobs,
)
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
if return_top_logprob:
(
logits_output.output_top_logprobs_val,
logits_output.output_top_logprobs_idx,
) = LogitsProcessor.get_top_logprobs(
next_token_logprobs, logits_metadata
)[
2:4
]
else:
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,
)

return logits_output
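
HiPCudaGraphRunner extends the stock CudaGraphRunner by capturing one CUDA graph per (batch size, use_cached_mask) pair and replaying whichever graph matches the incoming batch after copying its inputs into the pre-allocated static buffers. A self-contained toy sketch of that keying scheme with torch.cuda.CUDAGraph, using a dummy linear model in place of the real one (this is not the sglang runner; it needs a CUDA device):

import torch

assert torch.cuda.is_available(), "CUDA graphs need a GPU"
device = torch.device("cuda")
model = torch.nn.Linear(16, 16, device=device).eval()

max_bs = 8
static_x = torch.zeros(max_bs, 16, device=device)  # shared static input buffer
graphs, outputs = {}, {}

with torch.no_grad():
    for bs in (1, 2, 4, 8):
        for flag in (False, True):
            x = static_x[:bs]

            # Warm up on a side stream before capture (standard CUDA graph practice).
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                for _ in range(2):
                    model(x) * (0.5 if flag else 1.0)
            torch.cuda.current_stream().wait_stream(s)

            # One graph per (batch size, flag) key, mirroring self.graphs[(bs, use_cached_mask)].
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                out = model(x) * (0.5 if flag else 1.0)
            graphs[(bs, flag)] = g
            outputs[(bs, flag)] = out

def replay(batch: torch.Tensor, flag: bool) -> torch.Tensor:
    # The real runner also pads raw batch sizes up to the next captured size;
    # this toy assumes an exact match.
    bs = batch.shape[0]
    static_x[:bs].copy_(batch)   # refresh the static input buffer in place
    graphs[(bs, flag)].replay()  # rerun the kernels captured for this key
    return outputs[(bs, flag)][:bs]

print(replay(torch.randn(4, 16, device=device), True).shape)  # torch.Size([4, 16])
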
@@ -35,8 +35,6 @@
from sglang.srt.model_executor.hip_model_runner import HiPModelRunner
from sglang.srt.layers.attention.hip_attention.hip_config import HiPAttentionConfig

# from hip.models.hip_attention.attention2_draft_sampling_extend import dual_stage_quadratic_hip_attention
# from hip import HiPAttentionArgs
from hip.models.hip_attention.gen3.attention_extend import dual_stage_quadratic_hip_attention
from hip.models.hip_attention.gen3.attention_metadata import HiPAttentionArgs
from hip.models.hip_attention.gen3.uvm_gpu_cache import HiPOffloadCache
@@ -260,7 +258,7 @@ def forward_extend(
)

logger.debug(f'HiP attention is used in prompting (layer {layer.layer_id})!', stacklevel=0)

is_offload_cache = isinstance(forward_batch.token_to_kv_pool, MHATokenToHiPOffloadKVPool)

if is_offload_cache:
@@ -269,10 +267,10 @@ def forward_extend(
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
layer,
cache_loc,
k,
v,
async_copy=False
)
k_cache = v_cache = None
@@ -337,7 +335,6 @@ def forward_decode(
else forward_batch.encoder_out_cache_loc
)


require_dense = (
layer.layer_id in self.hip_config.dense_layers or
self.hip_config.decode_always_dense or
@@ -364,9 +361,10 @@ def forward_decode(
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
offload_cache = None

metadata = forward_batch.hip_metadata_cache_pool.get_hip_metadata_cache(
layer.layer_id, q.shape[0], forward_batch.batch_size)
#metadata = None
metadata = None
if forward_batch.hip_use_cached_mask:
metadata = forward_batch.hip_metadata_cache_pool.get_hip_metadata_cache(
layer.layer_id, q.shape[0], forward_batch.batch_size)

o, metadata = self.forward_paged_hip(
query=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
Expand All @@ -387,19 +385,8 @@ def forward_decode(
is_dense=require_dense,
)

#print("q shape", q.shape, layer.tp_q_head_num, layer.head_dim)
#print("k_cache shape", k_cache.shape)
#print("v_cache shape", v_cache.shape)
#print("positions", forward_batch.positions)
#print("seq_lens", forward_batch.seq_lens)
#print("metadata")
#print("indices", metadata.indices.shape)
#print("ks", metadata.ks.shape)
#print("ks_count", metadata.ks_count.shape)
#print("ks_start_end", metadata.ks_start_end.shape)

forward_batch.hip_metadata_cache_pool.set_hip_metadata_cache(
layer.layer_id, q.shape[0], forward_batch.batch_size, cache_loc, metadata)
layer.layer_id, q.shape[0], forward_batch.batch_size, metadata)

return o.view(-1, layer.tp_q_head_num * layer.head_dim)
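
With this change the decode path only pulls HiP metadata from the cache pool when forward_batch.hip_use_cached_mask is set; otherwise metadata stays None (so the mask is presumably rebuilt inside forward_paged_hip), and the fresh result is written back afterwards; note that the cache_loc argument to set_hip_metadata_cache is dropped. A toy illustration of that get-if-cached / recompute / set-back contract; the names below are hypothetical, not the sglang API:

from typing import Any, Callable, Optional

class ToyMetadataPool:
    def __init__(self):
        self._cache: dict[tuple[int, int, int], Any] = {}

    def get(self, layer_id: int, num_tokens: int, batch_size: int) -> Optional[Any]:
        return self._cache.get((layer_id, num_tokens, batch_size))

    def set(self, layer_id: int, num_tokens: int, batch_size: int, metadata: Any) -> None:
        self._cache[(layer_id, num_tokens, batch_size)] = metadata

def toy_decode_attention(
    pool: ToyMetadataPool,
    layer_id: int,
    num_tokens: int,
    batch_size: int,
    use_cached_mask: bool,
    attention: Callable[[Optional[Any]], tuple[Any, Any]],
):
    # metadata=None forces the kernel to rebuild the sparse mask from scratch.
    metadata = pool.get(layer_id, num_tokens, batch_size) if use_cached_mask else None
    out, metadata = attention(metadata)
    pool.set(layer_id, num_tokens, batch_size, metadata)
    return out
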

@@ -436,7 +423,7 @@ def forward_paged_hip(
dst_seq_len = N // batch_size

query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)

if k_cache is not None:
N_PAGE, num_heads_kv, hidden_dims_kv = k_cache.shape
assert v_cache.shape == k_cache.shape
@@ -471,7 +458,7 @@ def forward_paged_hip(
rope_sin=layer.rope_sin,

logit_softcap=layer.logit_cap,

second_stage_k=layer_config.second_stage_k,
stages=layer_config.stages,
model_context_length=layer.orig_context_len,
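
forward_paged_hip reshapes the flat (N, num_heads, head_dim) query tensor into (batch_size, dst_seq_len, num_heads, hidden_dims) with dst_seq_len = N // batch_size, which assumes every request in the batch contributes the same number of query tokens (one per request during decode). A quick shape check of that bookkeeping:

import torch

batch_size, num_heads, hidden_dims = 3, 8, 128
query = torch.randn(batch_size * 1, num_heads, hidden_dims)  # decode: 1 query token per request

N = query.shape[0]
assert N % batch_size == 0, "every request must contribute the same number of tokens"
dst_seq_len = N // batch_size
query = query.view(batch_size, dst_seq_len, num_heads, hidden_dims)
print(query.shape)  # torch.Size([3, 1, 8, 128])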