From 2295999cdb25cac96e1565bb1860bfd358c26672 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sat, 30 Dec 2023 23:42:13 -0500 Subject: [PATCH] [SLM] Batched Llama This PR introduces the batched llama modeling with Paged KV cache in SLM flow. --- cpp/llm_chat.cc | 19 +- python/mlc_chat/chat_module.py | 2 + python/mlc_chat/cli/gen_config.py | 7 + python/mlc_chat/help.py | 4 + python/mlc_chat/interface/compile.py | 19 +- python/mlc_chat/interface/gen_config.py | 3 + python/mlc_chat/model/llama/llama_model.py | 246 +++++++++++++++++---- python/mlc_chat/model/tir_inventory.py | 100 +++++++++ python/mlc_chat/model/utils/kv_cache.py | 154 +++++++++++++ tests/python/model/test_kv_cache.py | 141 ++++++++++++ 10 files changed, 643 insertions(+), 52 deletions(-) create mode 100644 python/mlc_chat/model/tir_inventory.py create mode 100644 python/mlc_chat/model/utils/kv_cache.py create mode 100644 tests/python/model/test_kv_cache.py diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 3d4456e3a0..dd2cdaef65 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -449,7 +449,7 @@ class LLMChat { this->sliding_window_size_ = config["sliding_window_size"].get(); CHECK(this->sliding_window_size_ > 0 || this->sliding_window_size_ == -1) << "Sliding window size needs to be -1 or positive"; - CHECK(config.count("prefill_chunk_size")) + CHECK(config.count("prefill_chunk_size") || this->sliding_window_size_ == -1) << "Need to specify chunk size if using sliding window attention."; } // to be removed after SLM migration @@ -460,7 +460,7 @@ class LLMChat { this->sliding_window_size_ = config["sliding_window"].get(); CHECK(this->sliding_window_size_ > 0 || this->sliding_window_size_ == -1) << "Sliding window size needs to be -1 or positive"; - CHECK(config.count("prefill_chunk_size")) + CHECK(config.count("prefill_chunk_size") || this->sliding_window_size_ == -1) << "Need to specify chunk size if using sliding window attention."; } if (config.count("prefill_chunk_size")) { @@ -579,7 +579,7 @@ class LLMChat { // Step 6. KV cache creation. this->kv_cache_ = ft_.create_kv_cache_func_(); // Step 7. Pre-allocate fixed size ndarray - this->temperature_arr_ = NDArray::Empty({}, DataType::Float(32), device_); + this->temperature_arr_ = NDArray::Empty({1}, DataType::Float(32), device_); float temperature = static_cast(this->temperature_); this->temperature_arr_.CopyFromBytes(&temperature, sizeof(float)); if (ft_.use_disco) { @@ -1084,7 +1084,7 @@ class LLMChat { CHECK(generation_config["temperature"].is()); *gen_temperature = generation_config["temperature"].get(); - *gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_); + *gen_temperature_arr = NDArray::Empty({1}, DataType::Float(32), device_); float temperature_cast = static_cast(*gen_temperature); gen_temperature_arr->CopyFromBytes(&temperature_cast, sizeof(float)); } else { @@ -1334,7 +1334,16 @@ class LLMChat { NDArray Softmax(NDArray input, NDArray temperature_arr) { NDArray ret; - ret = ft_.softmax_func_(input, temperature_arr); + try { + ret = ft_.softmax_func_(input, temperature_arr); + } catch (const dmlc::Error& e) { + // This branch is for compatibility: + // The old softmax function takes temperature arr with shape (), + // and the new softmax func takes temperature arr with shape (1,). + // Remove this branch after updating all prebuilt model libraries. 
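+      // NDArray::CreateView only reinterprets the same underlying storage with the
+      // scalar shape, so this fallback does not copy the temperature value.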
+ temperature_arr = temperature_arr.CreateView({}, temperature_arr->dtype); + ret = ft_.softmax_func_(input, temperature_arr); + } return ret; } diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 9762d751a1..18005be52d 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -186,6 +186,8 @@ class ChatConfig: # pylint: disable=too-many-instance-attributes tensor_parallel_shards: Optional[int] = None use_presharded_weights: Optional[bool] = None max_window_size: Optional[int] = None + context_window_size: Optional[int] = None + sliding_window_size: Optional[int] = None @classmethod def _from_json(cls, json_obj: dict): diff --git a/python/mlc_chat/cli/gen_config.py b/python/mlc_chat/cli/gen_config.py index 4ff09b5a8a..dd6848499d 100644 --- a/python/mlc_chat/cli/gen_config.py +++ b/python/mlc_chat/cli/gen_config.py @@ -76,6 +76,12 @@ def _parse_output(path: Union[str, Path]) -> Path: default=None, help=HELP["tensor_parallel_shards"] + ' (default: "%(default)s")', ) + parser.add_argument( + "--max-batch-size", + type=int, + default=80, + help=HELP["max_batch_size"] + ' (default: "%(default)s")', + ) parser.add_argument( "--output", "-o", @@ -95,5 +101,6 @@ def _parse_output(path: Union[str, Path]) -> Path: prefill_chunk_size=parsed.prefill_chunk_size, attention_sink_size=parsed.attention_sink_size, tensor_parallel_shards=parsed.tensor_parallel_shards, + max_batch_size=parsed.max_batch_size, output=parsed.output, ) diff --git a/python/mlc_chat/help.py b/python/mlc_chat/help.py index 55c019bbfe..dc016e8629 100644 --- a/python/mlc_chat/help.py +++ b/python/mlc_chat/help.py @@ -101,6 +101,10 @@ "attention_sink_size": """ (Experimental) The number of stored sinks. Only supported on Mistral yet. By default, the number of sinks is 4. This flag subjects to future refactoring. +""".strip(), + "max_batch_size": """ +The maximum allowed batch size set for batch prefill/decode function. +By default, the maximum batch size is 80. """.strip(), """tensor_parallel_shards""": """ Number of shards to split the model into in tensor parallelism multi-gpu inference. 
diff --git a/python/mlc_chat/interface/compile.py b/python/mlc_chat/interface/compile.py
index 06a761c812..e333c60396 100644
--- a/python/mlc_chat/interface/compile.py
+++ b/python/mlc_chat/interface/compile.py
@@ -152,16 +152,17 @@ def _apply_preproc_to_params(

 def _compile(args: CompileArgs, model_config: ConfigBase):
     def _get_variable_bounds(model_config) -> Dict[str, int]:
+        variable_bounds = {"seq_len": model_config.prefill_chunk_size}
         if hasattr(model_config, "sliding_window_size"):
-            return {
-                "seq_len": model_config.prefill_chunk_size,
-                "rolling_cache_len": model_config.sliding_window_size,
-                "kv_seq_len": model_config.sliding_window_size + model_config.prefill_chunk_size,
-            }
-        return {
-            "seq_len": model_config.prefill_chunk_size,
-            "total_seq_len": model_config.context_window_size,
-        }
+            variable_bounds["rolling_cache_len"] = model_config.sliding_window_size
+            variable_bounds["kv_seq_len"] = (
+                model_config.sliding_window_size + model_config.prefill_chunk_size
+            )
+        else:
+            variable_bounds["total_seq_len"] = model_config.context_window_size
+        if "max_batch_size" in model_config.kwargs:
+            variable_bounds["batch_size"] = model_config.kwargs["max_batch_size"]
+        return variable_bounds

     def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
         return {
diff --git a/python/mlc_chat/interface/gen_config.py b/python/mlc_chat/interface/gen_config.py
index ebf81b81c5..ed3aa10e40 100644
--- a/python/mlc_chat/interface/gen_config.py
+++ b/python/mlc_chat/interface/gen_config.py
@@ -33,6 +33,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
     prefill_chunk_size: int
     attention_sink_size: int
     tensor_parallel_shards: int
+    max_batch_size: int
     # Control the behavior of the runtime
     mean_gen_len: int = None
     max_gen_len: int = None
@@ -79,6 +80,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     prefill_chunk_size: Optional[int],
     attention_sink_size: Optional[int],
     tensor_parallel_shards: Optional[int],
+    max_batch_size: int,
     output: Path,
 ):
     """Entrypoint of MLC Chat configuration generation."""
@@ -101,6 +103,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
         prefill_chunk_size=model_config.prefill_chunk_size,
         attention_sink_size=getattr(model_config, "attention_sink_size", -1),
         tensor_parallel_shards=model_config.tensor_parallel_shards,
+        max_batch_size=max_batch_size,
         conv_template=conv_template,
     )
     # Step 2.
Load `generation_config.json` and `config.json` for text-generation related configs diff --git a/python/mlc_chat/model/llama/llama_model.py b/python/mlc_chat/model/llama/llama_model.py index 73826f3e7c..fc3631be57 100644 --- a/python/mlc_chat/model/llama/llama_model.py +++ b/python/mlc_chat/model/llama/llama_model.py @@ -15,6 +15,8 @@ from mlc_chat.support.config import ConfigBase from mlc_chat.support.style import bold +from ..utils.kv_cache import PagedKVCache + logger = logging.getLogger(__name__) @@ -102,31 +104,41 @@ def __init__(self, config: LlamaConfig): bias=False, ) self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False) + # KV cache for single sequence self.k_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) self.v_cache = nn.KVCache(config.context_window_size, [self.num_kv_heads, self.head_dim]) - def forward( # pylint: disable=too-many-locals + def forward( # pylint: disable=too-many-locals,too-many-arguments self, hidden_states: Tensor, - attention_mask: Tensor, - total_seq_len: tir.Var, + attention_mask: Optional[Tensor], + total_seq_len: Optional[tir.Var], + paged_kv_cache: Optional[PagedKVCache], + layer_id: int, ): d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len b, s, _ = hidden_states.shape - assert b == 1, "Only support batch size 1 at this moment." - # Step 1. QKV Projection + # QKV Projection qkv = self.qkv_proj(hidden_states) qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d)) - # Step 2. Apply QK rotary embedding - q, k, v = op_ext.llama_rope(qkv, t, self.rope_theta, h_q, h_kv) - # Step 3. Query and update KVCache - self.k_cache.append(op.squeeze(k, axis=0)) - self.v_cache.append(op.squeeze(v, axis=0)) - k = self.k_cache.view(t) - v = self.v_cache.view(t) - # Step 4. Compute softmax(Q @ K^T / sqrt(d)) @ V - output = op_ext.attention(q, k, v, casual_mask=attention_mask) - # Step 5. Apply output projection + + if paged_kv_cache is None: + # Single sequence attention. + assert attention_mask is not None + assert t is not None + # Step 1. Apply QK rotary embedding + q, k, v = op_ext.llama_rope(qkv, t, self.rope_theta, h_q, h_kv) + # Step 2. Query and update KVCache + self.k_cache.append(op.squeeze(k, axis=0)) + self.v_cache.append(op.squeeze(v, axis=0)) + k = self.k_cache.view(t) + v = self.v_cache.view(t) + # Step 3. Compute softmax(Q @ K^T / sqrt(d)) @ V + output = op_ext.attention(q, k, v, casual_mask=attention_mask) + else: + # Batch attention. 
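+            # Rotary embedding and the KV-cache update are handled inside
+            # PagedKVCache.attention (via the functions registered at cache
+            # creation time), so this path only splits the fused QKV projection.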
+ q, k, v = op.split(qkv, indices_or_sections=[h_q, h_q + h_kv], axis=2) + output = op.reshape(paged_kv_cache.attention(layer_id, q, k, v), (b, s, h_q * d)) return self.o_proj(output) @@ -156,13 +168,26 @@ def _set(layer, hint): self.tensor_parallel_shards = config.tensor_parallel_shards _set_tp() - def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): + def forward( # pylint: disable=too-many-arguments + self, + hidden_states: Tensor, + attention_mask: Optional[Tensor], + total_seq_len: Optional[tir.Var], + paged_kv_cache: Optional[PagedKVCache], + layer_id: int, + ): def _apply_residual(out, residual): if self.tensor_parallel_shards > 1: return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, "sum") return out + residual - out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len) + out = self.self_attn( + self.input_layernorm(hidden_states), + attention_mask, + total_seq_len, + paged_kv_cache, + layer_id, + ) hidden_states = _apply_residual(out, residual=hidden_states) out = self.mlp(self.post_attention_layernorm(hidden_states)) hidden_states = _apply_residual(out, residual=hidden_states) @@ -179,22 +204,45 @@ def __init__(self, config: LlamaConfig): self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) self.tensor_parallel_shards = config.tensor_parallel_shards - def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): - if self.tensor_parallel_shards > 1: - inputs = op.ccl_broadcast_from_worker0(inputs) - hidden_states = self.embed_tokens(inputs) - for layer in self.layers: - hidden_states = layer(hidden_states, attention_mask, total_seq_len) + def forward( # pylint: disable=too-many-arguments + self, + input_ids: Optional[Tensor], + input_embeds: Optional[Tensor], + total_seq_len: Optional[tir.Var], + attention_mask: Optional[Tensor], + paged_kv_cache: Optional[PagedKVCache], + ): + if input_ids is not None: + assert input_embeds is None + if self.tensor_parallel_shards > 1: + input_ids = op.ccl_broadcast_from_worker0(input_ids) + hidden_states = self.embed_tokens(input_ids) + else: + assert input_embeds is not None + if self.tensor_parallel_shards > 1: + input_embeds = op.ccl_broadcast_from_worker0(input_embeds) + hidden_states = input_embeds + + for layer_id, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, attention_mask, total_seq_len, paged_kv_cache, layer_id + ) hidden_states = self.norm(hidden_states) return hidden_states -class LlamaForCasualLM(nn.Module): +class LlamaForCasualLM(nn.Module): # pylint: disable=too-many-instance-attributes def __init__(self, config: LlamaConfig): self.model = LlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.num_hidden_layers = config.num_hidden_layers + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size self.vocab_size = config.vocab_size self.rope_theta = config.position_embedding_base + self.tensor_parallel_shards = config.tensor_parallel_shards self.dtype = "float32" def to(self, dtype: Optional[str] = None): @@ -202,21 +250,38 @@ def to(self, dtype: Optional[str] = None): if dtype is not None: self.dtype = dtype - def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): + def forward( # pylint: disable=too-many-arguments + self, + input_ids: Optional[Tensor] = None, + 
input_embeds: Optional[Tensor] = None, + total_seq_len: Optional[tir.Var] = None, + attention_mask: Optional[Tensor] = None, + paged_kv_cache: Optional[PagedKVCache] = None, + logit_positions: Optional[Tensor] = None, + ): op_ext.configure() def _index(x: te.Tensor): # x[:-1,:] b, s, d = x.shape return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name="index") - hidden_states = self.model(inputs, total_seq_len, attention_mask) - hidden_states = op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + hidden_states = self.model( + input_ids, input_embeds, total_seq_len, attention_mask, paged_kv_cache + ) + hidden_states = ( + op.tensor_expr_op(_index, name_hint="index", args=[hidden_states]) + if logit_positions is None + else op.take(hidden_states, logit_positions, axis=1) + ) logits = self.lm_head(hidden_states) if logits.dtype != "float32": logits = logits.astype("float32") return logits - def prefill(self, inputs: Tensor, total_seq_len: tir.Var): + def embed(self, input_ids: Tensor): + return self.model.embed_tokens(input_ids) + + def prefill(self, input_ids: Tensor, total_seq_len: tir.Var): def _attention_mask(batch_size, seq_len, total_seq_len): return te.compute( (batch_size, 1, seq_len, total_seq_len), @@ -228,31 +293,110 @@ def _attention_mask(batch_size, seq_len, total_seq_len): name="attention_mask_prefill", ) - batch_size, seq_len = inputs.shape + batch_size, seq_len = input_ids.shape attention_mask = op.tensor_expr_op( _attention_mask, name_hint="attention_mask_prefill", args=[batch_size, seq_len, total_seq_len], ) - return self.forward(inputs, total_seq_len, attention_mask) + return self.forward( + input_ids=input_ids, total_seq_len=total_seq_len, attention_mask=attention_mask + ) - def decode(self, inputs: Tensor, total_seq_len: tir.Var): - batch_size, seq_len = inputs.shape + def decode(self, input_ids: Tensor, total_seq_len: tir.Var): + batch_size, seq_len = input_ids.shape attention_mask = op.full( shape=[batch_size, 1, seq_len, total_seq_len], fill_value=tir.max_value(self.dtype), dtype=self.dtype, ) - return self.forward(inputs, total_seq_len, attention_mask) + return self.forward( + input_ids=input_ids, total_seq_len=total_seq_len, attention_mask=attention_mask + ) + + def prefill_batch( + self, input_embeds: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache + ): + logits = self.forward( + input_embeds=input_embeds, + logit_positions=logit_positions, + paged_kv_cache=paged_kv_cache, + ) + return logits, paged_kv_cache + + def decode_batch(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache): + logits = self.forward(input_embeds=input_embeds, paged_kv_cache=paged_kv_cache) + return logits, paged_kv_cache def softmax_with_temperature(self, logits: Tensor, temperature: Tensor): - return op.softmax(logits / temperature, axis=-1) + return op.softmax(logits / op.reshape(temperature, (temperature.shape[0], 1, 1)), axis=-1) + + def create_paged_kv_cache( + self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, page_size: tir.Var + ) -> PagedKVCache: + import tvm # pylint: disable=import-outside-toplevel + from tvm import relax # pylint: disable=import-outside-toplevel + + from ..tir_inventory import ( # pylint: disable=import-outside-toplevel + kv_cache_debug_get_kv, + kv_cache_transpose_append, + ) + + attn_prefill = tvm.get_global_func( + "paged_kv_cache.attention_kernel_prefill", allow_missing=True + ) + if attn_prefill is None: + return op.zeros((), self.dtype) + + bb = relax.BlockBuilder.current() + num_qo_heads = 
self.num_attention_heads // self.tensor_parallel_shards + num_kv_heads = self.num_key_value_heads // self.tensor_parallel_shards + return PagedKVCache.create( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + page_size=page_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=num_qo_heads, + num_key_value_heads=num_kv_heads, + head_dim=self.head_dim, + rope_scale=1, + rope_theta=self.rope_theta, + dtype=self.dtype, + f_transpose_append=bb.add_func( + kv_cache_transpose_append(num_kv_heads, self.head_dim, self.dtype), + "kv_cache_transpose_append", + ), + f_attn_prefill="paged_kv_cache.attention_kernel_prefill", + f_attn_decode="paged_kv_cache.attention_kernel_decode", + f_attn_prefill_ragged="flashinfer.attention_kernel_prefill_with_ragged_kv_cache", + f_attn_prefill_begin_forward="paged_kv_cache.attention_kernel_prefill_begin_forward", + f_attn_prefill_end_forward="paged_kv_cache.attention_kernel_prefill_end_forward", + f_attn_decode_begin_forward="paged_kv_cache.attention_kernel_decode_begin_forward", + f_attn_decode_end_forward="paged_kv_cache.attention_kernel_decode_end_forward", + f_attn_prefill_ragged_begin_forward="flashinfer.attention_kernel_prefill_with_ragged_kv_cache_begin_forward", # pylint: disable=line-too-long + f_attn_prefill_ragged_end_forward="flashinfer.attention_kernel_prefill_with_ragged_kv_cache_end_forward", # pylint: disable=line-too-long + f_attn_merge_state="flashinfer.merge_state_in_place", + f_attn_apply_rope="flashinfer.batch_qk_apply_rotary_in_place", + f_debug_get_kv=bb.add_func( + kv_cache_debug_get_kv( + self.num_hidden_layers, num_kv_heads, self.head_dim, self.dtype + ), + "kv_cache_debug_get_kv", + ), + ) def get_default_spec(self): batch_size = 1 mod_spec = { + "embed": { + "input_ids": nn.spec.Tensor([1, "seq_len"], "int32"), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "prefill": { - "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), + "input_ids": nn.spec.Tensor([batch_size, "seq_len"], "int32"), "total_seq_len": int, "$": { "param_mode": "packed", @@ -260,16 +404,42 @@ def get_default_spec(self): }, }, "decode": { - "inputs": nn.spec.Tensor([batch_size, 1], "int32"), + "input_ids": nn.spec.Tensor([batch_size, 1], "int32"), "total_seq_len": int, "$": { "param_mode": "packed", "effect_mode": "packed", }, }, + "prefill_batch": { + "input_embeds": nn.spec.Tensor([1, "seq_len", self.hidden_size], self.dtype), + "logit_positions": nn.spec.Tensor(["batch_size"], "int32"), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, + "decode_batch": { + "input_embeds": nn.spec.Tensor(["batch_size", 1, self.hidden_size], self.dtype), + "paged_kv_cache": nn.spec.Object(object_type=PagedKVCache), + "$": { + "param_mode": "packed", + "effect_mode": "none", + }, + }, "softmax_with_temperature": { - "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), - "temperature": nn.spec.Tensor([], "float32"), + "logits": nn.spec.Tensor(["batch_size", 1, self.vocab_size], "float32"), + "temperature": nn.spec.Tensor(["batch_size"], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "page_size": int, "$": { "param_mode": "none", "effect_mode": "none", diff --git a/python/mlc_chat/model/tir_inventory.py b/python/mlc_chat/model/tir_inventory.py new file mode 100644 index 0000000000..73c39c080c --- /dev/null +++ 
b/python/mlc_chat/model/tir_inventory.py @@ -0,0 +1,100 @@ +"""Inventory of common TensorIR functions""" +# pylint: disable=too-many-locals +from tvm.script import tir as T + + +def kv_cache_transpose_append(num_key_value_heads, head_dim, dtype): + """Return the TIR function that appends new k/v data to PagedKVCache.""" + + @T.prim_func + def tir_kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_position_map: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + ntoken = T.SizeVar("ntoken", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + + pages = T.match_buffer( + var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype + ) + k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype) + v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype) + position_map = T.match_buffer(var_position_map, (ntoken,), "int32") + + for global_pos, h, f in T.grid(ntoken, num_key_value_heads, head_dim): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes( + pages[ + position_map[vgpos] // page_size, 0, vh, position_map[vgpos] % page_size, vf + ] + ) + position: T.int32 = position_map[vgpos] # type: ignore + pages[ + T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vf + ] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + T.reads(position_map[vgpos], k_data[vgpos, vh, vf]) + T.writes( + pages[ + position_map[vgpos] // page_size, 1, vh, position_map[vgpos] % page_size, vf + ] + ) + position: T.int32 = position_map[vgpos] # type: ignore + pages[ + T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vf + ] = v_data[vgpos, vh, vf] + + return tir_kv_cache_transpose_append + + +def kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype): + """Return the TIR function that fetches the k/v data on given positions and layer.""" + + @T.prim_func + def tir_kv_cache_debug_get_kv( + var_pages: T.handle, + var_position_map: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + layer_id: T.int64, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + seqlen = T.SizeVar("seqlen", "int64") + page_size = T.SizeVar("page_size", "int64") + num_pages = T.int64() + + pages = T.match_buffer( + var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype + ) + position_map = T.match_buffer(var_position_map, (seqlen,), "int32") + k_data = T.match_buffer( + var_k_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype + ) + v_data = T.match_buffer( + var_v_data, (num_hidden_layers, seqlen, num_key_value_heads, head_dim), dtype + ) + + for p, h, d in T.grid(seqlen, num_key_value_heads, head_dim): + with T.block("copy0"): + vp, vh, vd = T.axis.remap("SSS", [p, h, d]) + T.reads( + position_map[vp], + pages[position_map[vp] // page_size, 0:2, vh, position_map[vp] % page_size, vd], + ) + T.writes(k_data[layer_id, vp, vh, vd], v_data[layer_id, vp, vh, vd]) + position: T.int32 = position_map[vp] + k_data[layer_id, vp, vh, vd] = pages[ + T.floordiv(position, page_size), 0, vh, T.floormod(position, page_size), vd + ] + v_data[layer_id, vp, vh, vd] = pages[ + T.floordiv(position, page_size), 1, vh, T.floormod(position, page_size), vd + ] + + return tir_kv_cache_debug_get_kv diff --git 
a/python/mlc_chat/model/utils/kv_cache.py b/python/mlc_chat/model/utils/kv_cache.py new file mode 100644 index 0000000000..f23d8726c9 --- /dev/null +++ b/python/mlc_chat/model/utils/kv_cache.py @@ -0,0 +1,154 @@ +from typing import Union + +from tvm import relax as rx +from tvm import tir +from tvm.ir import GlobalVar +from tvm.relax.frontend.nn import Object, Tensor + + +class PagedKVCache(Object): + """The Paged KV Cache used in LLM batching for efficient attention computation.""" + + @staticmethod + def create( + max_batch_size: tir.Var, + max_total_seq_len: tir.Var, + page_size: tir.Var, + num_hidden_layers: int, + num_attention_heads: int, + num_key_value_heads: int, + head_dim: int, + rope_scale: int, + rope_theta: int, + dtype: str, + f_transpose_append: Union[str, GlobalVar], + f_attn_prefill: Union[str, GlobalVar], + f_attn_decode: Union[str, GlobalVar], + f_attn_prefill_ragged: Union[str, GlobalVar], + f_attn_prefill_begin_forward: Union[str, GlobalVar], + f_attn_prefill_end_forward: Union[str, GlobalVar], + f_attn_decode_begin_forward: Union[str, GlobalVar], + f_attn_decode_end_forward: Union[str, GlobalVar], + f_attn_prefill_ragged_begin_forward: Union[str, GlobalVar], + f_attn_prefill_ragged_end_forward: Union[str, GlobalVar], + f_attn_merge_state: Union[str, GlobalVar], + f_attn_apply_rope: Union[str, GlobalVar], + f_debug_get_kv: Union[str, GlobalVar], + name: str = "paged_kv_cache", + ) -> "PagedKVCache": + """Create a paged KV cache object. + + Parameters + ---------- + max_batch_size : tir.Var + The maximum allowed batch size of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + max_total_seq_len : tir.Var + The maximum allowed total sequence length of the KV cache. + It is a symbolic variable whose concrete value is specified + at runtime. + page_size : tir.Var + The size (a.k.a. number of tokens) of each page. + It is a symbolic variable whose concrete value is specified + at runtime. + rope_scale : int + The scale of rotary position embedding. + rope_theta : int + The base of rotary position embedding. + f_transpose_append : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule that transposes + and appends new k/v data to KV cache. + f_attn_prefill : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule that computes + batch prefill attention with input q data and in-cache k/v data. + f_attn_decode : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule that computes + batch decode attention with input q data and in-cache k/v data. + f_attn_prefill_ragged : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule that computes + batch prefill with input input q/k/v data. + f_attn_prefill_begin_forward : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule for batch prefill + function pre-process. + f_attn_prefill_end_forward : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule for batch prefill + function post-process. + f_attn_decode_begin_forward : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule for batch decode + function pre-process. + f_attn_decode_end_forward : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule for batch decode + function post-process. + f_attn_prefill_ragged_begin_forward : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule for batch prefill + function pre-process. 
+ f_attn_prefill_ragged_end_forward : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule for batch prefill + function post-process. + f_attn_merge_state : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule that merges + attention outputs and scores. + f_attn_apply_rope : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule that applies + rotary position embedding to input q/k data. + f_debug_get_kv : Union[str, GlobalVar] + The external function name or GlobalVar in IRModule that returns + in-cache k/v data for a specified sequence. + """ + + def _convert_func(func: Union[str, GlobalVar]) -> rx.Expr: + return func if isinstance(func, GlobalVar) else rx.extern(func) + + return PagedKVCache( + _expr=rx.Call( + rx.extern("vm.builtin.paged_attention_kv_cache_create"), + args=[ + rx.ShapeExpr([max_batch_size, max_total_seq_len, page_size]), + rx.PrimValue(num_hidden_layers), + rx.PrimValue(num_attention_heads), + rx.PrimValue(num_key_value_heads), + rx.PrimValue(head_dim), + rx.PrimValue(rope_scale), + rx.PrimValue(rope_theta), + rx.op.zeros((), dtype), + _convert_func(f_transpose_append), + _convert_func(f_attn_prefill), + _convert_func(f_attn_decode), + _convert_func(f_attn_prefill_ragged), + _convert_func(f_attn_prefill_ragged_begin_forward), + _convert_func(f_attn_prefill_ragged_end_forward), + _convert_func(f_attn_prefill_begin_forward), + _convert_func(f_attn_prefill_end_forward), + _convert_func(f_attn_decode_begin_forward), + _convert_func(f_attn_decode_end_forward), + _convert_func(f_attn_apply_rope), + _convert_func(f_attn_merge_state), + _convert_func(f_debug_get_kv), + ], + sinfo_args=[rx.ObjectStructInfo()], + ), + _name=name, + ) + + def attention(self, layer_id: int, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + """Compute attention with the given q/k/v data and in-cache k/v data + on the specified layer. Rotary position embeddings are applied to k/v + within this function. + + - For prefill, the input q and output tensor have shape + (1, total_seq_len, num_attention_heads, head_dim), and the + k/v tensors have shape (1, total_seq_len, num_key_value_heads, head_dim). + - For decode, the input q and output tensor have shape + (batch_size, 1, num_attention_heads, head_dim), and the + k/v tensors have shape (batch_size, 1, num_key_value_heads, head_dim). 
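+
+        The given k/v data are appended to the paged cache (using the registered
+        transpose-append function) before attention is computed over all cached
+        tokens of the corresponding sequences, so callers do not invoke the
+        append function separately.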
+ """ + return Tensor( + _expr=rx.BlockBuilder.current().emit( + rx.call_dps_packed( + "vm.builtin.paged_attention_kv_cache_attention", + [self._expr, rx.PrimValue(layer_id), q._expr, k._expr, v._expr], + out_sinfo=q._expr.struct_info, + ) + ) + ) diff --git a/tests/python/model/test_kv_cache.py b/tests/python/model/test_kv_cache.py new file mode 100644 index 0000000000..0ddf1f5e48 --- /dev/null +++ b/tests/python/model/test_kv_cache.py @@ -0,0 +1,141 @@ +import tvm +from tvm import tir +from tvm.relax.frontend.nn import core, modules, spec +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T + +from mlc_chat.model.utils.kv_cache import PagedKVCache + + +def test_nn_module_paged_kv_cache(): + @I.ir_module + class Module: + @R.function + def _initialize_effect() -> R.Tuple(R.Object): + with R.dataflow(): + _io: R.Object = R.null_value() + lv: R.Tuple(R.Object) = (_io,) + gv: R.Tuple(R.Object) = lv + R.output(gv) + return gv + + @R.function + def create_paged_kv_cache( + max_batch_size: R.Shape(["max_batch_size_1"]), + max_total_seq_len: R.Shape(["max_total_seq_len_1"]), + page_size: R.Shape(["page_size_1"]), + _io: R.Object, + ) -> R.Tuple(R.Object, R.Tuple(R.Object)): + max_batch_size_1 = T.int64() + max_total_seq_len_1 = T.int64() + page_size_1 = T.int64() + R.func_attr({"num_input": 4}) + with R.dataflow(): + lv2: R.Tensor((), dtype="float16") = R.zeros(R.shape([]), dtype="float16") + paged_kv_cache: R.Object = R.call_packed( + "vm.builtin.paged_attention_kv_cache_create", + R.shape([max_batch_size_1, max_total_seq_len_1, page_size_1]), + R.prim_value(32), + R.prim_value(32), + R.prim_value(32), + R.prim_value(128), + R.prim_value(1), + R.prim_value(10000), + lv2, + R.ExternFunc("kv_cache_transpose_append"), + R.ExternFunc("attention_kernel_prefill"), + R.ExternFunc("attention_kernel_decode"), + R.ExternFunc("attention_kernel_prefill_with_ragged_kv_cache"), + R.ExternFunc("attention_kernel_prefill_with_ragged_kv_cache_begin_forward"), + R.ExternFunc("attention_kernel_prefill_with_ragged_kv_cache_end_forward"), + R.ExternFunc("attention_kernel_prefill_begin_forward"), + R.ExternFunc("attention_kernel_prefill_end_forward"), + R.ExternFunc("attention_kernel_decode_begin_forward"), + R.ExternFunc("attention_kernel_decode_end_forward"), + R.ExternFunc("batch_qk_apply_rotary_in_place"), + R.ExternFunc("merge_state_in_place"), + R.ExternFunc("kv_cache_debug_get_kv"), + sinfo_args=(R.Object,), + ) + gv2: R.Tuple(R.Object, R.Tuple(R.Object)) = paged_kv_cache, (_io,) + R.output(gv2) + return gv2 + + @R.function + def forward( + cache: R.Object, + q: R.Tensor((1, 100, 32, 128), dtype="float16"), + k: R.Tensor((1, 100, 32, 128), dtype="float16"), + v: R.Tensor((1, 100, 32, 128), dtype="float16"), + _io: R.Object, + ) -> R.Tuple(R.Tensor((1, 100, 32, 128), dtype="float16"), R.Tuple(R.Object)): + R.func_attr({"num_input": 5}) + with R.dataflow(): + lv1 = R.call_dps_packed( + "vm.builtin.paged_attention_kv_cache_attention", + (cache, R.prim_value(0), q, k, v), + out_sinfo=R.Tensor((1, 100, 32, 128), dtype="float16"), + ) + gv1: R.Tuple( + R.Tensor((1, 100, 32, 128), dtype="float16"), R.Tuple(R.Object) + ) = lv1, (_io,) + R.output(gv1) + return gv1 + + class PagedKVCacheTest(modules.Module): + def forward( + self, + cache: PagedKVCache, + q: core.Tensor, + k: core.Tensor, + v: core.Tensor, + ) -> core.Tensor: + return cache.attention(0, q, k, v) + + def create_paged_kv_cache( + self, max_batch_size: tir.Var, max_total_seq_len: tir.Var, page_size: tir.Var + 
) -> PagedKVCache: + return PagedKVCache.create( + max_batch_size=max_batch_size, + max_total_seq_len=max_total_seq_len, + page_size=page_size, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + head_dim=128, + rope_scale=1, + rope_theta=10000, + dtype="float16", + f_transpose_append="kv_cache_transpose_append", + f_attn_prefill="attention_kernel_prefill", + f_attn_decode="attention_kernel_decode", + f_attn_prefill_ragged="attention_kernel_prefill_with_ragged_kv_cache", + f_attn_prefill_begin_forward="attention_kernel_prefill_begin_forward", + f_attn_prefill_end_forward="attention_kernel_prefill_end_forward", + f_attn_decode_begin_forward="attention_kernel_decode_begin_forward", + f_attn_decode_end_forward="attention_kernel_decode_end_forward", + f_attn_prefill_ragged_begin_forward="attention_kernel_prefill_with_ragged_kv_cache_begin_forward", # pylint: disable=line-too-long + f_attn_prefill_ragged_end_forward="attention_kernel_prefill_with_ragged_kv_cache_end_forward", # pylint: disable=line-too-long + f_attn_merge_state="merge_state_in_place", + f_attn_apply_rope="batch_qk_apply_rotary_in_place", + f_debug_get_kv="kv_cache_debug_get_kv", + ) + + tvm_mod, _ = PagedKVCacheTest().export_tvm( + spec={ + "forward": { + "cache": spec.Object(object_type=PagedKVCache), + "q": spec.Tensor((1, 100, 32, 128), "float16"), + "k": spec.Tensor((1, 100, 32, 128), "float16"), + "v": spec.Tensor((1, 100, 32, 128), "float16"), + }, + "create_paged_kv_cache": { + "max_batch_size": int, + "max_total_seq_len": int, + "page_size": int, + }, + }, + debug=True, + ) + tvm.ir.assert_structural_equal(tvm_mod, Module, True)
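Note on the new softmax calling convention: `softmax_with_temperature` now takes one temperature per sequence (shape `(batch_size,)`) instead of a scalar, and reshapes it to `(batch_size, 1, 1)` so each sequence's logits are scaled by its own temperature. A minimal NumPy sketch of that broadcasting, with made-up toy shapes purely for illustration:

import numpy as np

batch_size, vocab_size = 4, 32000  # hypothetical sizes, not taken from the PR
logits = np.random.randn(batch_size, 1, vocab_size).astype("float32")
temperature = np.array([0.7, 1.0, 1.3, 0.5], dtype="float32")  # one value per sequence

# Mirrors softmax_with_temperature: scale each sequence's logits by its own
# temperature, then apply a numerically stable softmax over the vocab axis.
scaled = logits / temperature.reshape(batch_size, 1, 1)
probs = np.exp(scaled - scaled.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
assert probs.shape == (batch_size, 1, vocab_size)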