diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index dd90a67fb5..b0e47b27a9 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -7,6 +7,27 @@ namespace mlc { namespace llm { namespace { +Conversation ChatML() { + Conversation conv; + conv.name = "chatml"; + conv.roles = {"<|im_start|>user", "<|im_start|>assistant"}; + conv.system = + ("<|im_start|>system A conversation between a user and an LLM-based AI assistant. The " + "assistant gives helpful and honest answers.<|im_end|> "); + conv.messages = {}; + conv.offset = 0; + conv.separator_style = SeparatorStyle::kSepRoleMsg; + conv.seps = {"<|im_end|>", "<|im_end|>"}; + conv.role_msg_sep = "\n"; + conv.role_empty_sep = "\n"; + // TODO(mlc-team): add eos to mlc-chat-config + // and remove eos from stop token setting. + conv.stop_tokens = {2}; + conv.stop_str = "<|im_end|>"; + conv.add_bos = true; + return conv; +} + Conversation LlamaDefault() { Conversation conv; conv.name = "llama_default"; @@ -583,6 +604,7 @@ using ConvFactory = Conversation (*)(); Conversation Conversation::FromTemplate(const std::string& name) { static std::unordered_map factory = { + {"chatml", ChatML}, {"llama_default", LlamaDefault}, {"llama-2", Llama2}, {"mistral_default", MistralDefault}, diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 8416f836fc..1255c18bcc 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -409,10 +409,16 @@ class LLMChat { CHECK(!config.count("max_window_size")) << "Cannot specify both sliding_window and max_window_size."; this->sliding_window_ = config["sliding_window"].get(); + CHECK(this->sliding_window_ > 0) << "Sliding window size needs to be positive"; + CHECK(config.count("sliding_window_chunk_size")) + << "Need to specify chunk size if using sliding window attention."; } if (config.count("sliding_window_chunk_size")) { CHECK(config["sliding_window_chunk_size"].is()); this->sliding_window_chunk_size_ = config["sliding_window_chunk_size"].get(); + CHECK(this->sliding_window_chunk_size_ > 0) + << "Sliding window chunk size needs to be positive"; + CHECK(config.count("sliding_window")) << "Need to specify sliding window size."; } if (config.count("model_name")) { CHECK(config["model_name"].is()); @@ -828,13 +834,8 @@ class LLMChat { NDArray logits_on_device; if (this->sliding_window_ != -1) { // Use chunking if we use sliding window attention (see Mistral paper figure 3). - int64_t sliding_window_chunk_size = this->sliding_window_chunk_size_; - if (this->sliding_window_chunk_size_ == -1) { - // One chunk if chunk size not specified - sliding_window_chunk_size = token_len; - } - for (int64_t begin = 0; begin < token_len; begin += sliding_window_chunk_size) { - int64_t end = std::min(token_len, begin + sliding_window_chunk_size); + for (int64_t begin = 0; begin < token_len; begin += this->sliding_window_chunk_size_) { + int64_t end = std::min(token_len, begin + this->sliding_window_chunk_size_); std::vector chunk = std::vector(prompt_tokens.begin() + begin, prompt_tokens.begin() + end); new_seq_len += static_cast(chunk.size()); diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 8bdd93669b..1e1f8aed82 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -849,7 +849,7 @@ def build_model_from_args(args: argparse.Namespace): mod = mod_transform_before_build(mod, param_manager, args, model_config) if args.num_shards > 1: - # We requires a "create_sharding_info" function for all + # We require a "create_sharding_info" function for all # multi-GPU models, even if they are using pre-sharded # weights. 
When using pre-sharded weights, the list of # initialization-time transforms to apply is empty. diff --git a/mlc_llm/relax_model/mistral.py b/mlc_llm/relax_model/mistral.py index 1ef00ff577..31ed39fdb5 100644 --- a/mlc_llm/relax_model/mistral.py +++ b/mlc_llm/relax_model/mistral.py @@ -949,6 +949,9 @@ def get_model(args, hf_config): sliding_window_chunk_size=args.sliding_window_chunk_size, ) + assert config.sliding_window != -1 + assert config.sliding_window_chunk_size != -1 + param_manager = ParamManager() bb = relax.BlockBuilder() @@ -962,6 +965,8 @@ def get_model(args, hf_config): max_window_size=config.max_sequence_length, stop_tokens=[2], add_prefix_space=False, + sliding_window=config.sliding_window, + sliding_window_chunk_size=config.sliding_window_chunk_size, ) mod = bb.get() diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 649685c6b0..bcadaa84ba 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -8,13 +8,15 @@ import warnings from dataclasses import asdict, dataclass, fields from enum import Enum -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import tvm from tvm.runtime import disco # pylint: disable=unused-import -from .base import _LIB # pylint: disable=unused-import -from .interface.openai_api import ChatMessage +from . import base # pylint: disable=unused-import + +if TYPE_CHECKING: + from .interface.openai_api import ChatMessage # pylint: disable=line-too-long _PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb" @@ -41,10 +43,10 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes roles : Optional[List[str]] An array that describes the role names of the user and the model. These names are specific to the model being used. - messages : Optional[List[str]] + messages : Optional[List[List[str]]] The chat history represented as an array of string pairs in the following format: ``[[role_0, msg_0], [role_1, msg_1], ...]``. - offset : Optional[str] + offset : Optional[int] The offset used to begin the chat from the chat history. When offset is not ``0``, ``messages[0:offset-1]`` will be encoded. separator_style : Optional[int] @@ -69,7 +71,7 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes system: Optional[str] = None roles: Optional[List[str]] = None messages: Optional[List[List[str]]] = None - offset: Optional[str] = None + offset: Optional[int] = None separator_style: Optional[int] = None seps: Optional[List[str]] = None role_msg_sep: Optional[str] = None @@ -787,7 +789,7 @@ def __init__( def generate( self, - prompt: Union[str, List[ChatMessage]], + prompt: Union[str, List["ChatMessage"]], generation_config: Optional[GenerationConfig] = None, progress_callback=None, ) -> Union[str, List[str]]: @@ -797,14 +799,18 @@ def generate( Parameters ---------- - prompt : Union[str, List[ChatMessage]] + prompt: Union[str, List[ChatMessage]] The user input prompt, i.e. a question to ask the chat module. It can also be the whole conversation history (list of messages with role and content) - eg: ```[ - ChatMessage(role="user", content="Hello, how are you?"), - ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), - ChatMessage(role="user", content="I'm good too."), - ]``` + eg: + + .. code:: + + [ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. 
How about you?"), + ChatMessage(role="user", content="I'm good too."), + ] generation_config: Optional[GenerationConfig] The generation config object to override the ChatConfig generation settings. progress_callback: object @@ -841,8 +847,6 @@ def generate( if (generation_config is not None) and (generation_config.n is not None): num_return_sequences = generation_config.n return_str = False - else: - num_return_sequences = 1 for _ in range(num_return_sequences): self.reset_chat() @@ -1001,7 +1005,7 @@ def _unload(self): def _prefill( self, - input: Union[str, List[ChatMessage]], # pylint: disable=redefined-builtin + input: Union[str, List["ChatMessage"]], # pylint: disable=redefined-builtin decode_next_token: bool = True, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, generation_config: Optional[GenerationConfig] = None, @@ -1014,11 +1018,15 @@ def _prefill( input : Union[str, List[ChatMessage]] The user input prompt, i.e. a question to ask the chat module. It can also be the whole conversation history (list of messages with role and content) - eg: ```[ - ChatMessage(role="user", content="Hello, how are you?"), - ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), - ChatMessage(role="user", content="I'm good too."), - ]``` + eg: + + .. code:: + + [ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), + ChatMessage(role="user", content="I'm good too."), + ] decode_next_token : bool Whether to decode the next token after prefilling. place_in_prompt: PlaceInPrompt diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 678e924a78..ade62309c5 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -52,39 +52,46 @@ def _attach_auxiliary_methods( mod: IRModule, named_params: List[Tuple[str, nn.Parameter]], args: CompileArgs, - model_config, ) -> None: - def _metadata(): - metadata = { - "quantization": args.quantization.name, - "model_type": args.model.name, - "params": [ - { - "name": name, - "shape": list(param.shape), - "dtype": param.dtype, - } - for name, param in named_params - ], - } + def _get_memory_usage(): + return {str(k): int(v) for k, v in mod.attrs["mlc_llm.memory_usage"].items()} + + def _get_param_info(): + return [ + { + "name": name, + "shape": list(param.shape), + "dtype": param.dtype, + } + for name, param in named_params + ] + + def _emit_metadata(metadata): bb = relax.BlockBuilder() # pylint: disable=invalid-name with bb.function("main", params=[]): bb.emit_func_output(relax.StringImm(json.dumps(metadata))) return bb.get()["main"] - def _attach_variable_bounds(): - for g_var, func in mod.functions_items(): - if isinstance(func, relax.Function): - mod[g_var] = func.with_attr( - "tir_var_upper_bound", - { - "seq_len": model_config.max_sequence_length, - "total_seq_len": model_config.max_sequence_length, - }, - ) + mod["_metadata"] = _emit_metadata( + metadata={ + "quantization": args.quantization.name, + "model_type": args.model.name, + "memory_usage": _get_memory_usage(), + "params": _get_param_info(), + } + ) + - mod["_metadata"] = _metadata() - _attach_variable_bounds() +def _attach_variable_bounds(mod, model_config): + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + mod[g_var] = func.with_attr( + "tir_var_upper_bound", + { + "seq_len": model_config.max_sequence_length, + "total_seq_len": model_config.max_sequence_length, + }, + ) def _compile(args: 
CompileArgs): @@ -96,10 +103,11 @@ def _compile(args: CompileArgs): mod, named_params = model.export_tvm( spec=model.get_default_spec(), # type: ignore ) - _attach_auxiliary_methods(mod, named_params, args, model_config) logger.info("Running optimizations using TVM Unity") + _attach_variable_bounds(mod, model_config) with args.target: mod = relax.get_pipeline("mlc_llm")(mod) + _attach_auxiliary_methods(mod, named_params, args) logger.info("Generating code using TVM Unity") args.build_func(mod, args) logger.info("Generated: %s", bold(str(args.output))) diff --git a/python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py b/python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py new file mode 100644 index 0000000000..d6f959accf --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py @@ -0,0 +1,77 @@ +"""Memory usage estimation analysis function for Relax functions.""" +from typing import Dict + +import tvm +from tvm import relax +from tvm.ir import IRModule, Op +from tvm.relax.expr_functor import PyExprVisitor, visitor + + +@tvm.transform.module_pass(opt_level=0, name="EstimateMemoryUsage") +class EstimateMemoryUsage: # pylint: disable=too-few-public-methods + """A pass that attaches the memory usage information as an IRModule attribute. + + This pass relies on static analysis on each TVM Relax function in the specific IRModule. + It simply accumulates all memory allocation calls in a function, and does not consider + more dynamic runtime features like control flo "if" or function calls. + """ + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entry point of the pass.""" + lowered_mod = tvm.transform.Sequential( + [ + relax.transform.RewriteDataflowReshape(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + relax.transform.StaticPlanBlockMemory(), + ], + name="relax.lower", + )(mod) + usage = _MemoryEstimator().run(lowered_mod) + return mod.with_attr("mlc_llm.memory_usage", usage) + + +@visitor +class _MemoryEstimator(PyExprVisitor): + """The IR visitor which estimates the memory usage of each Relax function.""" + + def __init__(self) -> None: + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + self._op_alloc_tensor = Op.get("relax.builtin.alloc_tensor") + self._op_alloc_storage = Op.get("relax.memory.alloc_storage") + + def run(self, mod: IRModule) -> Dict[str, int]: + """Entry point of the visitor.""" + result: Dict[str, int] = {} + for global_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + self.visit_expr(func) + result[global_var.name_hint] = self.planned_alloc_mem + return result + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op == self._op_alloc_tensor: + self._builtin_tensor_alloc(shape=call.args[0], dtype_str=call.args[1].value) + elif call.op == self._op_alloc_storage: + self._storage_alloc(size=call.args[0]) + super().visit_call_(call) + + def _builtin_tensor_alloc(self, shape: relax.Expr, dtype_str: str) -> None: + assert isinstance(shape, relax.ShapeExpr) + size = 1 + for dim_len in shape.values: + if not isinstance(dim_len, tvm.tir.IntImm): + return + size *= dim_len.value + dtype = tvm.DataType(dtype_str) + self.planned_mem_num += 1 + self.planned_alloc_mem += size * ((dtype.bits + 7) // 8) * dtype.lanes + + def _storage_alloc(self, size: relax.Expr) -> None: + assert isinstance(size, 
relax.ShapeExpr) + self.planned_mem_num += 1 + self.planned_alloc_mem += size.values[0].value diff --git a/python/mlc_chat/compiler/compiler_pass/pipeline.py b/python/mlc_chat/compiler/compiler_pass/pipeline.py index f9bfdd0c59..1f8baab3b6 100644 --- a/python/mlc_chat/compiler/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler/compiler_pass/pipeline.py @@ -7,6 +7,7 @@ from tvm.relax import register_pipeline # pylint: disable=no-name-in-module from .clean_up_tir_attrs import CleanUpTIRAttrs +from .estimate_memory_usage import EstimateMemoryUsage from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise from .fuse_dequantize_take import FuseDequantizeTake from .fuse_dequantize_transpose import FuseDequantizeTranspose @@ -64,6 +65,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I _LogProgress("Running memory optimizations"), LiftTIRGlobalBufferAlloc(), tvm.tir.transform.ForceNarrowIndexToInt32(), + EstimateMemoryUsage(), ] ) mod = seq(mod._move()) # pylint: disable=protected-access diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py index 023db05e82..27d7db0825 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -27,7 +27,7 @@ class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes num_hidden_layers: int rms_norm_eps: float vocab_size: int - position_embedding_base: int = 10000 + position_embedding_base: int = 0 max_sequence_length: int = 0 num_key_value_heads: int = 0 head_dim: int = 0 @@ -49,6 +49,11 @@ def __post_init__(self): "`max_sequence_length` nor `max_position_embeddings` is provided " "in `config.json`." ) + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 if self.num_key_value_heads == 0: self.num_key_value_heads = self.num_attention_heads if self.head_dim == 0: @@ -60,6 +65,69 @@ def __post_init__(self): # pylint: disable=invalid-name,missing-docstring +class RMSNorm(nn.Module): + """ + Module for rms norm layer. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + hidden_size: int, + axes, # pylint: disable=unused-argument + epsilon: float = 1e-5, + bias: bool = True, + dtype: Optional[str] = None, + ): + super().__init__() + self.epsilon = epsilon + self.weight = nn.Parameter((hidden_size,), dtype=dtype) + if bias: + self.bias = nn.Parameter((hidden_size,), dtype=dtype) + else: + self.bias = None + + def forward(self, x: Tensor): + """ + Forward method for rms norm layer. + + Parameters + ---------- + x : Tensor + The input tensor. + + Returns + ------- + ret : Tensor + The output tensor for the rms norm layer. 
+ """ + + def f_square(x): + x = x.astype("float32") + return x * x + + def f_div_mult(x, square_sum, weight, *indices): + *i, k = indices + s = tir.sqrt(square_sum[*i] / x.shape[-1] + self.epsilon) + s = x[*i, k].astype("float32") / s + s = (weight[k] * s).astype(x.dtype) + return s + + def te_op(x: te.Tensor, weight: te.Tensor): + k = te.reduce_axis((0, x.shape[-1]), name="k") + square_sum = te.compute( + x.shape[:-1], + lambda *i: te.sum(f_square(x[*i, k]), axis=k), + name=x.op.name + "red_temp", + ) + return te.compute( + x.shape, + lambda *i: f_div_mult(x, square_sum, weight, *i), + name="rms_norm", + ) + + return op.tensor_expr_op(te_op, "rms_norm", args=[x, self.weight]) + + class RotaryEmbedding(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() @@ -80,9 +148,9 @@ def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): freq = (offset + s) / freq cos = tir.cos(freq).astype(dtype) * x[b, s, h, d] sin = tir.sin(freq).astype(dtype) * tir.if_then_else( - d < self.head_dim // 2, - -x[b, s, h, d + self.head_dim // 2], - x[b, s, h, d - self.head_dim // 2], + d < head_dim // 2, + -x[b, s, h, d + head_dim // 2], + x[b, s, h, d - head_dim // 2], ) return cos + sin @@ -146,8 +214,8 @@ def forward( # pylint: disable=too-many-locals self.k_cache.append(op.squeeze(k, axis=0)) self.v_cache.append(op.squeeze(v, axis=0)) - k = op.reshape(self.k_cache.view(total_seq_len), (b, t, h_kv, d)) - v = op.reshape(self.v_cache.view(total_seq_len), (b, t, h_kv, d)) + k = op.reshape(self.k_cache.view(t), (b, t, h_kv, d)) + v = op.reshape(self.v_cache.view(t), (b, t, h_kv, d)) if h_kv != h_q: k = k.repeat(h_q // h_kv, axis=2) v = v.repeat(h_q // h_kv, axis=2) @@ -163,11 +231,9 @@ def forward( # pylint: disable=too-many-locals attn_weights = op.softmax(attn_weights, axis=-1) else: attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype) - return self.o_proj( - op.matmul(attn_weights, v) # [b, h, s, t] x [b, h, t, d] = [b, h, s, d] - .permute_dims([0, 2, 1, 3]) # [b, s, h, d] - .reshape((b, s, h_q * d)) - ) + # [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d] + output = op.matmul(attn_weights, v) + return self.o_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h_q * d))) class LlamaDecoderLayer(nn.Module): @@ -175,8 +241,8 @@ def __init__(self, config: LlamaConfig, rotary_embedding: RotaryEmbedding): rms_norm_eps = config.rms_norm_eps self.self_attn = LlamaAttention(config, rotary_embedding) self.mlp = LlamaFFN(config) - self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) - self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.input_layernorm = RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): hidden_states = ( @@ -195,7 +261,7 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList( [LlamaDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] ) - self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.norm = RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): hidden_states = self.embed_tokens(inputs) diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py 
b/python/mlc_chat/compiler/quantization/group_quantization.py index 935621173b..ecce18d3c0 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -222,8 +222,11 @@ def _quantize( # pylint: disable=too-many-locals max_abs = te.compute( shape=scale_shape, fcompute=lambda i, j: te.max( - te.abs(weight[i, j * self.group_size + r]), - where=j * self.group_size + r < k, + tir.if_then_else( + j * self.group_size + r < k, + te.abs(weight[i, j * self.group_size + r]), + te.min_value(self.model_dtype), + ), axis=r, ), name="max_abs_value", @@ -251,9 +254,13 @@ def _quantize( # pylint: disable=too-many-locals quantized_weight = te.compute( shape=quantized_weight_shape, fcompute=lambda i, j: tir.sum( - scaled_weight[i, j * self.num_elem_per_storage + r] << (r * quantize_dtype.bits), + tir.if_then_else( + j * self.num_elem_per_storage + r < k, + scaled_weight[i, j * self.num_elem_per_storage + r] + << (r * quantize_dtype.bits), + 0, + ), axis=r, - where=j * self.num_elem_per_storage + r < k, ), name="weight", ) diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py index f84881c966..bae8d07aec 100644 --- a/python/mlc_chat/compiler/quantization/quantization.py +++ b/python/mlc_chat/compiler/quantization/quantization.py @@ -24,6 +24,14 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr """ QUANTIZATION: Dict[str, Quantization] = { + "q3f16_1": GroupQuantize( + name="q3f16_1", + kind="group-quant", + group_size=40, + quantize_dtype="int3", + storage_dtype="uint32", + model_dtype="float16", + ), "q4f16_1": GroupQuantize( name="q4f16_1", kind="group-quant", @@ -32,6 +40,14 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr storage_dtype="uint32", model_dtype="float16", ), + "q4f32_1": GroupQuantize( + name="q4f32_1", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float32", + ), "q4f16_awq": AWQQuantize( name="q4f16_awq", kind="awq", diff --git a/python/mlc_chat/compiler/quantization/utils.py b/python/mlc_chat/compiler/quantization/utils.py index 3470e42493..9a879d2e96 100644 --- a/python/mlc_chat/compiler/quantization/utils.py +++ b/python/mlc_chat/compiler/quantization/utils.py @@ -18,17 +18,11 @@ def convert_uint_to_float( # pylint: disable=too-many-arguments shape=[weight.shape[0], weight.shape[1] * num_elem_per_storage] if out_shape is None else out_shape, - fcompute=lambda i, j: tir.Cast( - model_dtype, - tir.bitwise_and( - tir.shift_right( - weight[i, j // num_elem_per_storage], - tir.Cast( - storage_dtype, - (j % num_elem_per_storage) * bits, - ), - ), - tir_bin_mask, + fcompute=lambda i, j: tir.bitwise_and( + tir.shift_right( + weight[i, j // num_elem_per_storage], + ((j % num_elem_per_storage) * bits).astype(storage_dtype), ), - ), + tir_bin_mask, + ).astype(model_dtype), ) diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 0000000000..c5e4e0d10a --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,20 @@ +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# Generated by Rust +**/*.rs.bk +/examples/pkg + +# MSVC Windows builds of rustc generate these, which store debugging information 
+*.pdb + +# IDE files +.idea/ +*.iml +.vscode/ diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 0000000000..58cc03f40b --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "mlc-llm" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tvm-rt = { path = "../3rdparty/tvm/rust/tvm-rt", features = ["dynamic-linking"] } +tracing = "0.1.32" +derive_builder = "0.12.0" +serde = { version = "1.0.160", features = ["derive"] } +serde_json = "1.0.107" diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 0000000000..8c92525772 --- /dev/null +++ b/rust/README.md @@ -0,0 +1,25 @@ +# MLC-LLM Rust Package + +This folder contains the source code of MLC-LLM Rust package. + +# Installations +To set up the MLC-LLM Rust package, please follow these steps: + +**Step 1:** Begin by following the detailed installation [instructions](https://llm.mlc.ai/docs/deploy/rest.html#optional-build-from-source) for TVM Unity and MLC-LLM. + +**Step 2:** Define the environment variables for TVM and MLC-LLM by running the following commands in your terminal: +```bash +export TVM_HOME=/path/to/tvm +export MLC_HOME=/path/to/mlc-llm +``` + +**Step 3:** Update your `LD_LIBRARY_PATH` to include the `libtvm_runtime` and `libmlc_llm_module` libraries. These can typically be found within the build directories of your TVM and MLC-LLM installations. + +# How to run it? +To start using the package, you can refer to the example code provided in the examples directory. This code demonstrates how to create a chat_module and serve prompts effectively. + +Execute the example with Cargo using the following command: +```bash +cargo run --example mlc_chat +``` + diff --git a/rust/build.rs b/rust/build.rs new file mode 100644 index 0000000000..d8e01a77e4 --- /dev/null +++ b/rust/build.rs @@ -0,0 +1,4 @@ +fn main() { + println!("cargo:rustc-link-lib=dylib=mlc_llm_module"); + println!("cargo:rustc-link-search=native={}/build", env!("MLC_HOME")); +} diff --git a/rust/examples/mlc_chat.rs b/rust/examples/mlc_chat.rs new file mode 100644 index 0000000000..2e87d56946 --- /dev/null +++ b/rust/examples/mlc_chat.rs @@ -0,0 +1,10 @@ +extern crate mlc_llm; + +use mlc_llm::chat_module::ChatModule; + +fn main() { + let cm = ChatModule::new("/path/to/Llama2-13B-q8f16_1", "rocm", None).unwrap(); + let output = cm.generate("what is the meaning of life?", None).unwrap(); + println!("resp: {:?}", output); + println!("stats: {:?}", cm.stats(false)); +} diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs new file mode 100644 index 0000000000..831905eee8 --- /dev/null +++ b/rust/src/chat_module.rs @@ -0,0 +1,445 @@ +use std::fs; +use std::path::{Path, PathBuf}; +use std::result; +use tracing::info; +use tvm_rt::{function::Function, Module}; + +use super::config::*; + +#[derive(Debug)] +pub enum ChatModuleError { + /// Global function in a TVM Module is not found + GlobalFuncNotFound, + /// TVM Runtime error + TvmRuntime(tvm_rt::Error), +} + +impl From for ChatModuleError { + fn from(e: tvm_rt::Error) -> Self { + Self::TvmRuntime(e) + } +} + +pub type Result = result::Result; + +/// The ChatModule for MLC LLM. 
Generate a response for a given prompt
+/// let output = cm.generate("What is the meaning of life?", None).unwrap();
+///
+/// // Print prefill and decode performance statistics
+/// println!("Statistics: {:?}\n", cm.stats(false).unwrap());
+///
+/// let output = cm.generate("What is Rust?", None).unwrap();
+/// ```
+pub struct ChatModule {
+    chat_module: Module,
+    chat_config: ChatConfig,
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum PlaceInPrompt {
+    All = 0,
+    Begin = 1,
+    Middle = 2,
+    End = 3,
+}
+
+impl PlaceInPrompt {
+    pub fn to_value(&self) -> i32 {
+        *self as i32
+    }
+}
+
+/// Parse the input device identifier into device name and id.
+///
+/// # Parameters
+/// * `device` - The device identifier to parse. It can be in the format "device_name" (e.g., "cuda")
+/// or "device_name:device_id" (e.g., "cuda:1").
+///
+/// # Returns
+/// * `device_name` - The name of the device.
+/// * `device_id` - The id of the device, or 0 if not specified in the input.
+fn parse_device_str(device: &str) -> (&str, i32) {
+    let device_err_msg = format!(
+        "Invalid device name: {}. Please enter the device in the form \
+        'device_name:device_id' or 'device_name', where 'device_name' needs to be \
+        one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.",
+        device
+    );
+    let device_args: Vec<&str> = device.split(':').collect();
+    match device_args.len() {
+        1 => (device_args[0], 0),
+        2 => (device_args[0], device_args[1].parse::<i32>().unwrap()),
+        _ => panic!("{}", device_err_msg),
+    }
+}
+
+/// Use user-provided argument `model` to search for a valid model path.
+/// We define "valid" as having an `mlc-chat-config.json` right under the folder.
+///
+/// # Parameters
+/// * `model`: User's input; may be a compiled model's name, or a full path.
+///
+/// # Returns
+/// * `model_path`: A "valid" path to model folder with `mlc-chat-config.json` existing under it.
+/// * `chat_file`: The path to the `mlc-chat-config.json` file.
+///
+/// # Panics
+/// * If a valid model_path cannot be found.
+pub fn get_model_path(model: &str) -> (PathBuf, PathBuf) { + // Note that the order of this list corresponds to our search priority + let candidate_paths = vec![ + PathBuf::from(model), // full path, or just the name + PathBuf::from(format!("{}/params", model)), // Default directory after mlc_llm.build_model() + PathBuf::from(format!("dist/prebuilt/{}", model)), // Using prebuilt workflow + PathBuf::from(format!("dist/{}/params", model)), // Default directory after mlc_llm.build_model() in the current path + PathBuf::from(format!("dist/prebuilt/mlc-chat-{}", model)), // Also prebuilt workflow, but missed prefix + ]; + + // Look for the first folder that has `mlc-chat-config.json` under it + for candidate in &candidate_paths { + let chat_file = candidate.join("mlc-chat-config.json"); + if chat_file.is_file() { + info!( + "Using model folder: {:?}", + candidate.canonicalize().unwrap() + ); + info!( + "Using mlc chat config: {:?}", + chat_file.canonicalize().unwrap() + ); + return (candidate.clone(), chat_file); + } + } + + let mut found_folder = false; + let mut valid_dir_str = String::new(); + for candidate in &candidate_paths { + if candidate.is_dir() { + valid_dir_str += &format!("- {:?}\n", candidate.canonicalize().unwrap()); + found_folder = true; + } + } + + if found_folder { + // Error 1: there is a folder, but not an mlc-llm model folder (E1) + let err_msg = format!( + "The model folder provided does not seem to refer to a valid mlc-llm model folder.\n\ + Specifically, we cannot find `mlc-chat-config.json`, a required file. You should \ + provide a path that contains the file.\n\ + According to your input `model`, we looked at folder(s):\n\ + {}\n\ + MLC-Chat consumes models that are processed by the MLC-LLM build process.\n\ + ", + valid_dir_str, + ); + panic!("{}", err_msg); + } else { + // Error 2: cannot find a folder (E0) + let all_paths_str = candidate_paths + .iter() + .map(|path| format!("- {}\n", path.display())) + .collect::(); + let err_msg = format!( + "Cannot find the model folder. We searched over the following possible paths:\n\ + {}\n\ + You can try to pass in `model=/path/to/your-model-path`, and confirm \ + that it contains `mlc-chat-config.json`, among other essential files.\n\ + ", + all_paths_str, + ); + panic!("{}", err_msg); + } +} + +/// Read in the config file in model path, then potentially override with user input. +/// +/// # Parameters: +/// * `config_file_path`: &Path +/// `chat_file` returned by a function like `get_model_path()`. +fn get_chat_config( + config_file_path: &Path, +) -> result::Result> { + // Read the base configuration from the file + let file_contents = fs::read_to_string(config_file_path)?; + let final_chat_config = ChatConfig::from_json(&file_contents)?; + Ok(final_chat_config) +} + +fn get_lib_module_path( + model: &str, + model_path: &Path, + chat_config: &ChatConfig, + model_lib_path: Option<&str>, + device_name: &str, + config_file_path: &Path, +) -> PathBuf { + // 1. Use user's model_lib_path if provided + if let Some(lib_path) = model_lib_path { + let path = Path::new(lib_path); + if path.is_file() { + info!("Using library model: {:?}", path); + return path.to_path_buf(); + } else { + panic!( + "The `model_lib_path` you passed in is not a file: {:?}.", + lib_path + ); + } + } + + // 2. 
Generate all possible file names according to OS
+    let mut candidate_paths = Vec::new();
+    if let Some(model_lib) = &chat_config.model_lib {
+        let candidate_lib_names: Vec<String> = if cfg!(target_os = "linux") {
+            vec![format!("{}-{}.so", model_lib, device_name)]
+        } else if cfg!(target_os = "macos") {
+            vec![
+                format!("{}-{}.dylib", model_lib, device_name),
+                format!("{}-{}.so", model_lib, device_name),
+            ]
+        } else if cfg!(target_os = "windows") {
+            vec![format!("{}-{}.dll", model_lib, device_name)]
+        } else {
+            vec![
+                format!("{}-{}.dylib", model_lib, device_name),
+                format!("{}-{}.so", model_lib, device_name),
+                format!("{}-{}.dll", model_lib, device_name),
+            ]
+        };
+
+        // 3. Generate possible model library paths
+        let pardir_model_path = model_path.parent().unwrap();
+        for lib_name in &candidate_lib_names {
+            let paths: Vec<String> = vec![
+                lib_name.clone(),
+                format!("dist/prebuilt/lib/{}", lib_name),
+                format!("dist/{}/{}", model, lib_name),
+                model_path.join(lib_name).to_string_lossy().into_owned(),
+                pardir_model_path
+                    .join(lib_name)
+                    .to_string_lossy()
+                    .into_owned(),
+            ];
+
+            candidate_paths.extend(paths);
+        }
+
+        // 4. Search for model library
+        for candidate in &candidate_paths {
+            let candidate_path = Path::new(candidate);
+            if candidate_path.is_file() {
+                info!("Using library model: {:?}", candidate_path);
+                return candidate_path.to_path_buf();
+            }
+        }
+
+        // 5. Error
+        let mut err_msg = format!(
+            "Cannot find the model library that corresponds to `{:?}`.\n\
+            `{:?}` is either provided in the `chat_config` \
+            you passed in, or specified in {:?}.\n\
+            We searched over the following possible paths: \n",
+            model_lib, model_lib, config_file_path
+        );
+        for candidate in &candidate_paths {
+            err_msg += &format!("- {}\n", candidate);
+        }
+        err_msg += &format!(
+            "If you would like to directly specify the model library path, you may \
+            consider passing in the `ChatModule.model_lib_path` parameter."
+        );
+
+        panic!("{}", err_msg);
+    } else {
+        panic!("Cannot find the model library, you need to either pass it in, or specify in the chat_config file.");
+    }
+}
+
+impl ChatModule {
+    pub fn new(model: &str, device: &str, model_lib_path: Option<&str>) -> Result {
+        let device_err_msg = format!(
+            "Invalid device name: {}. Please enter the device in the form \
+            'device_name:device_id' or 'device_name', where 'device_name' needs to be \
+            one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.",
+            device
+        );
+
+        let (device_name, device_id) = parse_device_str(device);
+
+        // 1. Get device name and id
+        let device_type = match device_name {
+            "cuda" => 2,
+            "opencl" => 4,
+            "vulkan" => 7,
+            "metal" => 8,
+            "rocm" => 10,
+            _ => panic!("{}", device_err_msg),
+        };
+
+        static GLOBAL_FUNC_NAME: &str = "mlc.llm_chat_create";
+        let f = Function::get(GLOBAL_FUNC_NAME).ok_or(ChatModuleError::GlobalFuncNotFound)?;
+        let m: Module = f
+            .invoke(vec![device_type.into(), device_id.into()])
+            .unwrap()
+            .try_into()
+            .expect("call should succeed");
+
+        // 2. Look up the model path
+        let (model_path, config_file_path) = get_model_path(model);
+
+        // 3. Instantiate chat_config
+        let chat_config = get_chat_config(&config_file_path).unwrap();
+
+        // 4.
Look up the model library + let model_lib_path = get_lib_module_path( + model, + &model_path, + &chat_config, + model_lib_path, + device_name, + &config_file_path, + ); + + let chat_mod = Self { + chat_module: m, + chat_config: chat_config, + }; + let model_lib_str = model_lib_path.as_path().display().to_string(); + let model_path_str = model_path.as_path().display().to_string(); + chat_mod + .reload(&model_lib_str, &model_path_str, "") + .unwrap(); + Ok(chat_mod) + } + + /// Reload the chat module from the given library and model path. + fn reload(&self, lib: &str, model_path: &str, app_config_json: &str) -> Result<()> { + let f = self.chat_module.get_function("reload", false)?; + f.invoke(vec![lib.into(), model_path.into(), app_config_json.into()])?; + Ok(()) + } + + /// Reset the chat session, clear all chat history, and potentially + /// override the original `mlc-chat-config.json`. + pub fn reset_chat(&self) -> Result<()> { + // TODO: add optional user-specified ChatConfig + let f = self.chat_module.get_function("reset_chat", false)?; + f.invoke(vec![])?; + Ok(()) + } + + /// Get the runtime stats of the encoding step, decoding step (and embedding step if exists) + /// of the chat module in text form. + pub fn stats(&self, verbose: bool) -> Result { + if verbose { + let f = self + .chat_module + .get_function("verbose_runtime_stats_text", false)?; + let res: String = f.invoke(vec![])?.try_into().expect("call should succeed"); + return Ok(res); + } + let f = self.chat_module.get_function("runtime_stats_text", false)?; + let res: String = f.invoke(vec![])?.try_into().expect("call should succeed"); + return Ok(res); + } + + /// Check if the stop condition is met for the current round. + fn stopped(&self) -> Result { + let f = self.chat_module.get_function("stopped", false)?; + let res: bool = f.invoke(vec![])?.try_into().expect("call should succeed"); + Ok(res) + } + + /// Get the output message in the current round. + fn get_message(&self) -> Result { + let f = self.chat_module.get_function("get_message", false)?; + let res: String = f.invoke(vec![])?.try_into().expect("call should succeed"); + Ok(res) + } + + /// Decode the next token, the decoding result is stored in a buffer and + /// can be retrieved by [get_message]. + fn decode(&self, generation_config: Option<&GenerationConfig>) -> Result<()> { + let generation_config_str = match generation_config { + Some(config) => serde_json::to_string(config).unwrap(), + None => { + let config = GenerationConfig::from_chat_config(&self.chat_config); + serde_json::to_string(&config).unwrap() + } + }; + let f = self.chat_module.get_function("decode", false)?; + f.invoke(vec![generation_config_str.into()])?; + Ok(()) + } + + /// A high-level method that returns the full response from the chat module given a user + /// prompt. User can optionally specify which callback method to use upon receiving the + /// response. 
+ pub fn generate( + &self, + prompt: &str, + generation_config: Option<&GenerationConfig>, + ) -> Result> { + // TODO: add progress_callback + let mut new_msgs: Vec = vec![]; + let mut num_return_sequences: usize = 1; + + if let Some(gc) = generation_config { + if let Some(n) = gc.n { + num_return_sequences = n; + } + } + + for _ in 0..num_return_sequences { + self.reset_chat().unwrap(); + self.prefill(prompt, true, PlaceInPrompt::All, generation_config) + .unwrap(); + + while !self.stopped().unwrap() { + self.decode(generation_config)?; + } + let new_msg = self.get_message().unwrap(); + new_msgs.push(new_msg); + } + + Ok(new_msgs) + } + + /// Run prefill stage for a given input and optionally decode the first output token. + /// User can decide where to place the input in the prompt. + fn prefill( + &self, + input: &str, + decode_next_token: bool, + place_in_promt: PlaceInPrompt, + generation_config: Option<&GenerationConfig>, + ) -> Result<()> { + let generation_config_str = match generation_config { + Some(config) => serde_json::to_string(config).unwrap(), + None => { + let config = GenerationConfig::from_chat_config(&self.chat_config); + serde_json::to_string(&config).unwrap() + } + }; + + let f = self.chat_module.get_function("prefill", false)?; + f.invoke(vec![ + input.into(), + (&decode_next_token).into(), + place_in_promt.to_value().into(), + generation_config_str.into(), + ])?; + Ok(()) + } +} + diff --git a/rust/src/config.rs b/rust/src/config.rs new file mode 100644 index 0000000000..61371d197a --- /dev/null +++ b/rust/src/config.rs @@ -0,0 +1,276 @@ +use serde::{Deserialize, Serialize}; + +/// A struct that represents user-defined partial configuration for conversation template. +/// +/// This can be passed in to the instantiation of a [ChatModule](crate::chat_module::ChatModule) +/// instance to override the default setting in `mlc-chat-config.json` under the +/// model folder. Note that we will first load the predefined template +/// with the name specified in `conv_template`. +/// +/// Since the configuration is partial, everything will be optional. +#[derive(Clone, Default, Builder, Debug, Serialize, Deserialize)] +#[builder(default)] +pub struct ConvConfig { + /// Name of the conversation. + name: Option, + + /// The prompt encoded before starting the chat. + system: Option, + + /// An array that describes the role names of the user and the model. + roles: Option>, + + /// The chat history represented as an array of string pairs. + messages: Option>>, + + /// The offset used to begin the chat from the chat history. + offset: Option, + + /// Specifies whether we are in chat-bot mode (`0`) or pure LM prompt mode (`1`). + separator_style: Option, + + /// An array of strings indicating the separators to be used after a user message and a model message respectively. + seps: Option>, + + /// A string indicating the separator between a role and a message. + role_msg_sep: Option, + + /// A string indicating the separator to append to a role when there is no message yet. + role_empty_sep: Option, + + /// When the `stop_str` is encountered, the model will stop generating output. + stop_str: Option, + + /// A list of token IDs that act as stop tokens. + stop_tokens: Option>, + + /// Determines whether a beginning-of-string (bos) token should be added before the input tokens. 
+ add_bos: Option, +} + +impl ConvConfig { + pub fn post_init(&mut self) { + if let Some(messages) = &self.messages { + if self.offset.is_none() { + self.offset = Some(messages.len()); + } + } + } +} + +/// A struct that represents user-defined partial configuration for the chat config file. +/// +/// An instance of [ChatConfig] can be passed in to override the default setting. +/// Since the configuration is partial, everything will be optional. +/// +/// Note: This struct is used to represent the chat config during intermediate processing. +#[derive(Builder, Debug, Default, Serialize, Deserialize)] +#[builder(default)] +pub struct ChatConfig { + /// The necessary model library to launch this model architecture. + /// Recommended to reuse model library when possible. + pub model_lib: Option, + + /// Uniquely identifying the model in application. Also used by + /// CLI to specify which model to run. + pub local_id: Option, + + /// The name of the conversation template that this chat uses. + pub conv_template: Option, + + /// Temperature applied to logits before sampling. Encourages diverse outputs if higher. + pub temperature: Option, + + /// Controls the likelihood of the model generating repeated texts. + /// See the CTRL paper for more details: + repetition_penalty: Option, + + /// Determines the set of tokens from which we sample during decoding. + /// More info on top-p sampling: + top_p: Option, + + /// Approximated average number of generated tokens in each round. + mean_gen_len: Option, + + /// Maximum number of tokens to be generated in each round. + max_gen_len: Option, + + /// Fraction of maximum window size to shift when it is exceeded. + shift_fill_factor: Option, + + /// List of tokenizer files of the model. + tokenizer_files: Option>, + + /// Partial overriding configuration for conversation template. + conv_config: Option, + + /// The category of the model's architecture (e.g. `llama`, `gpt_neox`, `rwkv`). + model_category: Option, + + /// Name of the model (e.g. `Llama-2-7b-chat-hf`). + model_name: Option, + + /// Tensor parallel degree. + num_shards: Option, + + /// Maximum kv cache window size. + max_window_size: Option, +} + +impl ChatConfig { + pub fn from_json(json_str: &str) -> Result { + serde_json::from_str(json_str) + } +} + +/// A struct that represents user-defined generation configuration. +/// +/// An instance of [GenerationConfig] can be passed into the +/// [ChatModule::generate](crate::chat_module::ChatModule::generate) function +/// to override the default generation settings specified in `mlc-chat-config.json` +/// and `ChatConfig` under the model folder. +/// +/// Once the generation ends, `GenerationConfig` is discarded, as the values +/// are only intended to override the `ChatConfig` generation settings during a +/// single generation, unless it is recurrently passed to the `generate` function. +/// This allows for changing generation settings over time, without permanently +/// overriding the `ChatConfig`. +/// +/// Since the configuration is partial, all fields are optional. +#[derive(Builder, Debug, Default, Serialize, Deserialize)] +#[builder(default)] +pub struct GenerationConfig { + /// The temperature applied to logits before sampling. The default value is + /// `0.7`. A higher temperature encourages more diverse outputs, while a + /// lower temperature produces more deterministic outputs. + temperature: Option, + + /// The repetition penalty controls the likelihood of the model generating + /// repeated texts. 
The default value is set to `1.0`, indicating that no + /// repetition penalty is applied. Increasing the value reduces the + /// likelihood of repeat text generation. However, setting a high + /// `repetition_penalty` may result in the model generating meaningless + /// texts. The ideal choice of repetition penalty may vary among models. Only + /// Active when presence_penalty and frequency_penalty are both `0.0`. + + /// For more details on how repetition penalty controls text generation, please + /// check out the CTRL paper . + repetition_penalty: Option, + + /// This parameter determines the set of tokens from which we sample during + /// decoding. The default value is set to `0.95`. At each step, we select + /// tokens from the minimal set that has a cumulative probability exceeding + /// the ``top_p` parameter. + + /// For additional information on top-p sampling, please refer to this blog + /// post: . + top_p: Option, + + /// The approximated average number of generated tokens in each round. Used + /// to determine whether the maximum window size would be exceeded. + mean_gen_len: Option, + + /// This parameter determines the maximum length of the generated text. If it is + /// not set, the model will generate text until it encounters a stop token. + max_gen_len: Option, + + /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on + /// whether they appear in the text so far, increasing the model's likelihood + /// to talk about new topics. Negative values can increase the likelihood of + /// repetition. + presence_penalty: Option, + + /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on their + /// existing frequency in the text so far, decreasing the model's likelihood to + /// repeat the same line verbatim. Negative values can increase the likelihood of + /// repetition. + frequency_penalty: Option, + + /// This parameter determines the number of text samples to generate. The default + /// value is `1`. Note that this parameter is only used when `stream` is set to + /// `false`. + pub n: Option, + + /// When `stop` is encountered, the model will stop generating output. + /// It can be a string or a list of strings. If it is a list of strings, the model + /// will stop generating output when any of the strings in the list is encountered. + /// Note that this parameter does not override the default stop string of the model. 
+ stop: Option>, +} + +impl GenerationConfig { + pub fn from_chat_config(chat_config: &ChatConfig) -> Self { + Self { + temperature: chat_config.temperature, + repetition_penalty: chat_config.repetition_penalty, + top_p: chat_config.top_p, + mean_gen_len: chat_config.mean_gen_len, + max_gen_len: chat_config.max_gen_len, + presence_penalty: Some(0.0), + frequency_penalty: Some(0.0), + n: Some(0), + stop: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_conv_config() { + let mut config = ConvConfig { + messages: Some(vec![vec![ + "User: Hi".to_string(), + "Assistant: Hello".to_string(), + ]]), + offset: None, + ..Default::default() + }; + config.post_init(); + assert_eq!(config.offset, Some(1)); + } + + #[test] + fn test_chat_config() { + let json_data = r#" + { + "model_lib": "some_lib", + "local_id": "id123", + "temperature": 0.7 + } + "#; + + let config = ChatConfig::from_json(json_data).unwrap(); + + assert_eq!(config.model_lib, Some("some_lib".to_string())); + assert_eq!(config.local_id, Some("id123".to_string())); + assert_eq!(config.temperature, Some(0.7)); + let _pretty_json = serde_json::to_string_pretty(&config).unwrap(); + } + + #[test] + fn test_generation_config() { + let chat_config = ChatConfigBuilder::default() + .temperature(Some(0.7)) + .top_p(Some(0.8)) + .mean_gen_len(Some(50)) + .max_gen_len(Some(75)) + .build() + .unwrap(); + + let gen_config = GenerationConfig::from_chat_config(&chat_config); + + assert_eq!(gen_config.temperature, chat_config.temperature); + assert_eq!( + gen_config.repetition_penalty, + chat_config.repetition_penalty + ); + assert_eq!(gen_config.top_p, chat_config.top_p); + assert_eq!(gen_config.mean_gen_len, chat_config.mean_gen_len); + assert_eq!(gen_config.max_gen_len, chat_config.max_gen_len); + assert_eq!(gen_config.presence_penalty, Some(0.0)); + assert_eq!(gen_config.frequency_penalty, Some(0.0)); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 0000000000..e83534ceeb --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,5 @@ +#[macro_use] +extern crate derive_builder; + +pub mod chat_module; +pub mod config; diff --git a/tests/legacy-python/module_intercept.py b/tests/legacy-python/module_intercept.py new file mode 100644 index 0000000000..e63bb21de6 --- /dev/null +++ b/tests/legacy-python/module_intercept.py @@ -0,0 +1,147 @@ +"""This script is an example of running and comparing the outputs of two different TVM Relax VMs. +""" +# pylint: disable=missing-docstring,invalid-name +import json + +import numpy as np +import torch +import tvm +from transformers import LlamaTokenizer +from tvm import relax +from tvm.contrib import tvmjs + +KVCACHE_FUNCS = [ + "vm.builtin.attention_kv_cache_append", + "vm.builtin.attention_kv_cache_view", +] +DEVICE = "cuda:0" +PROMPT = "What is the meaning of life?" 
+TOKENIZER = "./dist/debug-llama/" + +COMBO = { + "CURRENT": { + "model_lib": "./dist/debug-llama/llama.so", + "params": "./dist/debug-llama", + "target_func": "fused_fused_dequantize1_NT_matmul6", + }, + "LEGACY": { + "model_lib": "./dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so", + "params": "./dist/Llama-2-7b-chat-hf-q4f16_1/params", + "target_func": "fused_fused_decode2_NT_matmul", + }, +} + + +class Instrument: # pylint: disable=too-few-public-methods + def __init__( + self, + target_func: str, + ): + self.first_time = True + self.target_func = target_func + self.saved_args = [] # type: ignore + + def __call__( + self, + func, + func_symbol: str, + before_run: bool, + ret_value, + *args, + ): + if before_run: + return + if func_symbol.startswith("vm.builtin."): + if func_symbol not in KVCACHE_FUNCS: + return + if func_symbol == self.target_func and self.first_time: + self.first_time = False + for arg in args: + print(arg.shape, arg.dtype) + self.saved_args.append(arg.numpy()) + + +class TestState: + def __init__(self, device, model_lib, target_func): + self.mod = relax.VirtualMachine( + tvm.runtime.load_module(model_lib), + device, + ) + self.inst = Instrument(target_func=target_func) + self.mod.set_instrument(self.inst) + + +def _tokenize(sentence: str): + tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER, trust_remote_code=True) + tokens = tokenizer(PROMPT, return_tensors="pt").input_ids.to(torch.int32).numpy() + print(f"Tokenizing: {sentence}") + print(f"Tokens: {tokens}") + return tokens + + +def _load_params(params, device, metadata): + param_dict, _ = tvmjs.load_ndarray_cache(params, device) + param_list = [] + for name in [x["name"] for x in metadata["params"]]: + param_list.append(param_dict[name]) + return param_list + + +def _load_params_legacy(params, device): + param_dict, metadata = tvmjs.load_ndarray_cache(params, device) + param_list = [] + for i in range(metadata["ParamSize"]): + param_list.append(param_dict[f"param_{i}"]) + return param_list + + +def _as_input_tuple(scalar): + return tvm.runtime.ShapeTuple([scalar]) + + +@tvm.register_func("debug_save") +def _debug_save(x, _): + return tvm.nd.array(x.numpy(), x.device) + + +def main() -> None: + device = tvm.device(DEVICE) + prompt = _tokenize(PROMPT) + + def _run_legacy(model_lib, params, target_func): + state = TestState(device, model_lib, target_func) + kv_cache = state.mod["create_kv_cache"]() + param_list = _load_params_legacy(params, device) + state.mod["prefill"]( + tvm.nd.array(prompt, device), + _as_input_tuple(len(prompt[0])), + kv_cache, + param_list, + ) + return state.inst.saved_args + + def _run_current(model_lib, params, target_func): + state = TestState(device, model_lib, target_func) + metadata = json.loads(state.mod["_metadata"]()) + kv_cache = state.mod["_initialize_effect"]() + param_list = _load_params(params, device, metadata) + state.mod["prefill"]( + tvm.nd.array(prompt, device), + _as_input_tuple(len(prompt[0])), + kv_cache, + param_list, + ) + return state.inst.saved_args + + print("============== Running old flow =================") + new_args = _run_current(**COMBO["CURRENT"]) + print("============== Running new flow =================") + old_args = _run_legacy(**COMBO["LEGACY"]) + + for i, (new_arg, old_arg) in enumerate(zip(new_args, old_args)): + print(f"Checking arg {i}") + np.testing.assert_allclose(new_arg, old_arg, rtol=1e-12, atol=1e-12) + + +if __name__ == "__main__": + main() diff --git a/tests/python/quantization/test_group_quantization.py 
b/tests/python/quantization/test_group_quantization.py index 106d0f5fb5..04b23e91d3 100644 --- a/tests/python/quantization/test_group_quantization.py +++ b/tests/python/quantization/test_group_quantization.py @@ -35,8 +35,10 @@ def quantize_np(config: GroupQuantize, weight: np.ndarray): 0, config.max_int_value * 2, ).astype(config.storage_dtype) + weight_filtered = np.reshape(weight_scaled_reshaped, (n, k)) + weight_filtered[..., weight.shape[1] :] = 0 weight_scaled = np.reshape( - weight_scaled_reshaped, (n, k // config.num_elem_per_storage, config.num_elem_per_storage) + weight_filtered, (n, k // config.num_elem_per_storage, config.num_elem_per_storage) ) indice_k = np.indices(weight_scaled.shape, dtype=config.storage_dtype)[-1] quantized_weight = np.sum( @@ -53,6 +55,7 @@ def dequantize_np( scale: np.ndarray, out_shape: List[int] = None, ): + assert weight.shape[0] == scale.shape[0] bin_mask = (1 << DataType(config.quantize_dtype).bits) - 1 max_int = config.max_int_value out_shape = ( @@ -70,13 +73,21 @@ def dequantize_np( ), bin_mask, ) - return ((weight_bin - max_int) * scale_repeated)[: out_shape[0]][: out_shape[1]] + assert weight_bin.shape[1] <= scale_repeated.shape[1] + return ((weight_bin - max_int) * scale_repeated[..., : weight_bin.shape[1]])[ + : out_shape[0], : out_shape[1] + ] @pytest.mark.parametrize( "quant_name, shape, dtype, device", [ + ("q3f16_1", [2, 13], "float16", "cpu"), + ("q3f16_1", [16, 120], "float16", "cpu"), + ("q4f16_1", [2, 13], "float16", "cpu"), ("q4f16_1", [16, 128], "float16", "cpu"), + ("q4f32_1", [2, 13], "float32", "cpu"), + ("q4f32_1", [16, 128], "float32", "cpu"), ], ) def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: str): @@ -90,15 +101,20 @@ def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: tvm.testing.assert_allclose( dequantize_np(config, quantized_weight, scale, shape), dequantize_np(config, quantized_weight_ref, scale_ref, shape), - rtol=1e-3, - atol=0.2, + rtol=1e-2 if quant_name.startswith("q3") else 1e-3, + atol=0.4 if quant_name.startswith("q3") else 0.2, ) @pytest.mark.parametrize( "quant_name, shape, dtype", [ + ("q3f16_1", [2, 13], "float16"), + ("q3f16_1", [16, 120], "float16"), + ("q4f16_1", [2, 13], "float16"), ("q4f16_1", [16, 128], "float16"), + ("q4f32_1", [2, 13], "float32"), + ("q4f32_1", [16, 128], "float32"), ], ) def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): @@ -115,9 +131,9 @@ def forward(self, x: nn.Tensor): weight_np = np.random.randint( np.iinfo(config.storage_dtype).min, np.iinfo(config.storage_dtype).max, - (shape[0], shape[1] // config.num_elem_per_storage), + (shape[0], -(shape[1] // -config.num_elem_per_storage)), ).astype(config.storage_dtype) - scale_np = np.random.random((shape[0], shape[1] // config.group_size)).astype( + scale_np = np.random.random((shape[0], -(shape[1] // -config.group_size))).astype( config.model_dtype ) mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") @@ -127,14 +143,16 @@ def forward(self, x: nn.Tensor): out = model["forward"]( torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member ) - ref = dequantize_np(config, weight_np, scale_np).T + ref = dequantize_np(config, weight_np, scale_np, shape).T tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize( "quant_name, shape, dtype", [ + ("q3f16_1", [16, 128], "float16"), ("q4f16_1", [16, 128], "float16"), + ("q4f32_1", [16, 128], "float32"), ], ) def test_quantize_model(quant_name: 
str, shape: List[int], dtype: str):
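
A few illustrative notes on the changes above. First, the `llm_chat.cc` hunk makes `sliding_window_chunk_size` mandatory whenever `sliding_window` is set, and prefill always proceeds in fixed-size chunks. A minimal Python sketch of that chunking loop (names are illustrative, not the C++ API):

```python
from typing import List


def chunked_prefill(prompt_tokens: List[int], sliding_window_chunk_size: int) -> List[List[int]]:
    """Split a prompt into chunks, mirroring the chunking loop in the llm_chat.cc hunk above."""
    assert sliding_window_chunk_size > 0, "chunk size must be positive (now CHECKed in llm_chat.cc)"
    chunks = []
    for begin in range(0, len(prompt_tokens), sliding_window_chunk_size):
        end = min(len(prompt_tokens), begin + sliding_window_chunk_size)
        chunks.append(prompt_tokens[begin:end])
    return chunks


# A 10-token prompt with chunk size 4 is prefilled as three forward calls: 4 + 4 + 2 tokens.
print(chunked_prefill(list(range(10)), 4))
```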
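
Second, `compile.py` now embeds the output of the new `EstimateMemoryUsage` pass into the `_metadata` function of the compiled artifact. A hedged sketch of reading it back, following the same pattern `tests/legacy-python/module_intercept.py` uses (the library path and device are placeholders):

```python
import json

import tvm
from tvm import relax

# Placeholder library path and device; adapt to your build output.
vm = relax.VirtualMachine(
    tvm.runtime.load_module("./dist/debug-llama/llama.so"), tvm.device("cuda", 0)
)

# `_metadata` returns the JSON string emitted by _attach_auxiliary_methods.
metadata = json.loads(vm["_metadata"]())

# "memory_usage" maps each Relax function to its statically planned allocation size in bytes.
for func_name, num_bytes in metadata["memory_usage"].items():
    print(f"{func_name}: {num_bytes / 2**20:.2f} MiB")
```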
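
Third, `llama_model.py` now carries a hand-written `RMSNorm` instead of `nn.RMSNorm`. A small NumPy reference of what its `te_op` computes (the square-sum reduction runs in float32 and the result is cast back to the model dtype), useful for eyeballing the kernel; illustrative only:

```python
import numpy as np


def rms_norm_ref(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """NumPy equivalent of the f_square / f_div_mult computation in RMSNorm.te_op."""
    x32 = x.astype("float32")
    # s = sqrt(sum(x^2) / d + eps), reduced over the last axis.
    s = np.sqrt(np.mean(x32 * x32, axis=-1, keepdims=True) + eps)
    return (weight.astype("float32") * (x32 / s)).astype(x.dtype)


x = np.random.rand(2, 4, 8).astype("float16")
w = np.ones(8, dtype="float16")
print(rms_norm_ref(x, w, eps=1e-6).shape)  # (2, 4, 8)
```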
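
Finally, a rough storage calculation for the new `q3f16_1` mode registered in `quantization.py`. This assumes weights are packed `32 // 3 = 10` elements per `uint32` storage word, which is my reading of the shift-by-`r * bits` packing in `_quantize`; treat the numbers as a sanity check rather than a spec:

```python
storage_bits = 32     # uint32 storage_dtype
quantize_bits = 3     # int3 quantize_dtype
group_size = 40       # group size from the q3f16_1 entry in QUANTIZATION

num_elem_per_storage = storage_bits // quantize_bits        # 10 weights per uint32 word (assumed)
words_per_group = -(group_size // -num_elem_per_storage)    # ceil(40 / 10) = 4 words per group
scale_bytes_per_group = 2                                   # one float16 scale per group
print(num_elem_per_storage, words_per_group, scale_bytes_per_group)
```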