Merge upstream nov11 #59

Merged: 117 commits, Nov 9, 2023

Commits
898db76
[API] Add GenerationConfig (#1024)
davidpissarra Oct 8, 2023
ad3a6b9
Fix two bugs in kv-cache backtrack loop (#856)
shenberg Oct 8, 2023
6e40c21
[Build] Added --pdb flag to build.py, drop into pdb on error (#1017)
Lunderberg Oct 8, 2023
bae37b3
[Android] Use `AlertDialog` instead of `Toast` (#1039)
cyx-6 Oct 8, 2023
b44f679
Add doc for ChatConfig, ConvConfig, GenerationConfig, BuildArgs (#1040)
CharlieFRuan Oct 9, 2023
3a9849a
[Android] Add Llama2 q4f16_0 (#1041)
spectrometerHBH Oct 9, 2023
bed9e60
[Docs] Model prebuilts tracking page revamp (#1000)
CharlieFRuan Oct 9, 2023
c02fdaf
Update compile_models.rst (#1038)
yongjer Oct 9, 2023
85001ed
Support for the Stable LM 3B model (#1008)
jeethu Oct 9, 2023
a032d40
[Docs] Iterate model prebuilts docs (#1043)
CharlieFRuan Oct 9, 2023
a58605f
Update README.md
junrushao Oct 9, 2023
bdd9d9b
[CPP] Separate common utils out from llm_chat.cc (#1044)
MasterJH5574 Oct 9, 2023
20131fb
Update README.md (#1045)
junrushao Oct 9, 2023
1e6fb11
add verbose stats to mlc-chat REST API (#1049)
denise-k Oct 11, 2023
b9179cf
[Transform] Apply split_rotary optimization on prefill (#1033)
Lunderberg Oct 12, 2023
98ebd28
[Docs] Add `mlc.ai/package` to `DEPENDENCY INSTALLATION` group (#1055)
LeshengJin Oct 12, 2023
bfaa5b9
Revert "[Transform] Apply split_rotary optimization on prefill (#1033…
MasterJH5574 Oct 12, 2023
ca8c11b
[BugFix] Set the right `max_sequence_length` for both Llama-1 and Lla…
sunggg Oct 13, 2023
edab9b5
[Doc] Use -U instead of --force-reinstall (#1062)
junrushao Oct 13, 2023
d854105
[Model] Initial batching support for Llama (#1048)
MasterJH5574 Oct 14, 2023
c2b8cbc
Fix Stable LM 3B build (#1061)
jeethu Oct 14, 2023
481cd92
[Core] Remove duplication in MODEL.get_model calls (#1054)
Lunderberg Oct 14, 2023
8184431
[ParamManager] Cleanup creation of quantization IRModule (#1053)
Lunderberg Oct 14, 2023
9010d48
Minor typo fix (#1064)
jeethu Oct 15, 2023
b0bfc88
Add links to Python API Reference (#1068)
junrushao Oct 15, 2023
204860b
[Fix] ChatModule incorrect temperature buffer shape (#1070)
MasterJH5574 Oct 15, 2023
d202077
[ParamManager] Added progress bar for get_item/set_item (#1063)
Lunderberg Oct 16, 2023
9872c48
[Python] Extract common device str parse function in ChatModule (#1074)
MasterJH5574 Oct 16, 2023
3aefd9f
[Bugfix] Compilation Error in q4f32_1 (#1078)
junrushao Oct 17, 2023
2625945
Establish `mlc_chat.compiler` (#1082)
junrushao Oct 19, 2023
56a8004
Update README.md for Multi-GPU (#1090)
junrushao Oct 19, 2023
b0373d1
Support lib_path override in C++. Improvements on docs and error mess…
rickzx Oct 19, 2023
830656f
StreamIterator (#1057)
varshith15 Oct 19, 2023
9bf5723
Update `benchmark.py` according to #1086 (#1091)
junrushao Oct 19, 2023
62d0c03
Disable Disco for q4f16_ft and q8f16_ft quantization (#1094)
LeshengJin Oct 20, 2023
cf39bf6
[Format] Apply isort and black for `python/` (#1097)
junrushao Oct 20, 2023
e9b85ce
More formatting (#1099)
junrushao Oct 21, 2023
03c641a
Enable Python Linter (#1098)
junrushao Oct 21, 2023
46d11e6
Add Basic Pylint and Mypy Tooling (#1100)
junrushao Oct 21, 2023
6159cc4
[CI] Add clang-format (#1103)
junrushao Oct 22, 2023
16dd2ae
[Slim-LM] Smart path finding for config and weight (#1088)
LeshengJin Oct 23, 2023
f57c9c9
[Transform] Provide IRModule transform for rewrite_attention (#1052)
Lunderberg Oct 23, 2023
e5927ce
[ParamManager] Use BundleModelParams for transform_dequantize (#1056)
Lunderberg Oct 23, 2023
7ae8c6d
[Slim-LM] Introduce HFLoad for loading Pytorch and SafeTensor weights…
LeshengJin Oct 23, 2023
5a7dcd8
[WINDOWS] reduce noise in windows build (#1115)
tqchen Oct 24, 2023
61179a0
Add CLI commands for compilation (#1109)
junrushao Oct 24, 2023
8ce7793
Auto updated submodule references
Oct 24, 2023
488017d
fix mismatched argument name (#1117)
Sing-Li Oct 24, 2023
206103b
[Docs] Add doc for max and mean gen len, shift factor; and buildArgs …
CharlieFRuan Oct 24, 2023
2aa6809
Revert "[ParamManager] Use BundleModelParams for transform_dequantize…
junrushao Oct 24, 2023
9cb8e8e
Remove inaccurate warning message (#1121)
junrushao Oct 24, 2023
9166edb
[REST] OpenAI compatible Rest API (#1107)
Kartik14 Oct 24, 2023
a4279e3
Add --opt flag parsing to CLI (#1123)
junrushao Oct 25, 2023
973f9fc
[ParamManager][Redo] Use BundleModelParams for transform_dequantize (…
Lunderberg Oct 25, 2023
24f795e
added details to windows installation (#1133)
goutham2688 Oct 27, 2023
2c492e5
Grammatical and Typographical improvements (#1139)
tmsagarofficial Oct 28, 2023
2ec0cc8
Minor enhancements to `ChatModule` (#1132)
YuchenJin Oct 28, 2023
27ac5ac
Updating tvm install docs (#1143)
David-Sharma Oct 29, 2023
2b6d832
Make the help info consistent with program name (#1137)
fennecJ Oct 29, 2023
878ae84
Support parameter packing (#1146)
junrushao Oct 29, 2023
c0c3a8d
[Slim-LM] Enable Group Quant (#1129)
zxybazh Oct 29, 2023
2193767
Enable Mypy and Pylint in mlc_chat Python Package (#1149)
junrushao Oct 29, 2023
0a25374
Migrate Compiler Passes (#1150)
junrushao Oct 30, 2023
1a79a53
Compile Model Preset without External `config.json` (#1151)
junrushao Oct 30, 2023
ba67835
Update attention layer (#1153)
junrushao Oct 30, 2023
fee2cb5
Add batched Llama model definition using vLLM paged attention (#1134)
masahi Oct 30, 2023
ece97b1
[Transform][Redo] Apply split_rotary optimization on prefill (#1125)
Lunderberg Oct 30, 2023
b190578
Apply rewrite for normal attention and MQA (#1138)
Lunderberg Oct 30, 2023
8ca0176
[Rest] Fix emoji handling in Rest API. (#1142)
YuchenJin Oct 30, 2023
3cf5605
[Utility] Check for isinstance(exc, Exception) before entering pdb (#…
Lunderberg Oct 30, 2023
0a9d6c7
[Utils] Remove conversion to numpy array in utils.save_params (#1083)
Lunderberg Oct 30, 2023
425a2cb
[Fix][REST] Use lowered-cased "app" (#1159)
junrushao Oct 30, 2023
9076d01
[Rest] Document emoji handling (#1160)
YuchenJin Oct 31, 2023
b5bfa5b
Enable group quant transform with nn.Module (#1154)
cyx-6 Oct 31, 2023
8438b27
Misc Cleanups of Compilation Pipeline (#1165)
junrushao Oct 31, 2023
02d1e57
Support CUDA Multi-Arch Compilation (#1166)
junrushao Oct 31, 2023
e0cd3f6
[Bugfix] Cannot find global function `mlc.llm_chat_create` (#1167)
junrushao Oct 31, 2023
f5b2e88
Fix RWKV Support (#1136)
BBuf Nov 1, 2023
200653a
Auto updated submodule references
Nov 1, 2023
9831135
Fix Android app Permission denied error on Android 10 (#1175)
anibohara2000 Nov 1, 2023
1757777
[SLM] Fix group quantization (#1172)
cyx-6 Nov 1, 2023
2ca7d15
[Fix] TIR block name of dequantization (#1177)
junrushao Nov 2, 2023
53060af
[SLM][AutoLLM] Enable Command Line Weight Conversion (#1170)
zxybazh Nov 2, 2023
2dc8183
[Fix][SLM] Update q4f16 quantization with the new mutator name rule (…
LeshengJin Nov 3, 2023
6ae02dd
[Model Support][SWA] Add support for sliding window attention for Mis…
CharlieFRuan Nov 3, 2023
4716704
Add Python API for Weight Conversion (#1182)
junrushao Nov 4, 2023
9d20575
Merge `llama_config.CONFIG` into `MODEL_PRESETS` (#1188)
junrushao Nov 4, 2023
5d1dc34
Merge llama_config.py into llama_model.py (#1189)
junrushao Nov 4, 2023
4832c2f
Add CodeLlama as part of model presets (#1190)
junrushao Nov 4, 2023
78424f0
[Docs] Clarify zstd installation on Windows (#1191)
junrushao Nov 4, 2023
5d63f7e
[Docs] Clarify zstd installation on Windows (#1196)
junrushao Nov 4, 2023
3417505
Support overriding `--max-sequence-length` in command line (#1197)
junrushao Nov 5, 2023
0e08845
[RestAPI] Added docs (#1193)
anibohara2000 Nov 5, 2023
145a984
[API] ```llm-vscode``` extension support (#1198)
davidpissarra Nov 5, 2023
3413d17
[Fix] Use `fabs` as floating point abs function in C++ (#1202)
junrushao Nov 5, 2023
7ccb51a
Integrating MLC runtime with the new compilation workflow (#1203)
junrushao Nov 6, 2023
65478c8
[Fix] Remove Redundant Warnings (#1204)
junrushao Nov 6, 2023
01d4339
Try fix macOS build with picojson (#1206)
junrushao Nov 6, 2023
51d6f9c
Try fix macOS build with picojson again (#1207)
junrushao Nov 6, 2023
a7f1183
Auto updated submodule references
Nov 6, 2023
e2c99a8
[Fix] Keep update-to-date with upstream API change (#1209)
junrushao Nov 6, 2023
e00220c
Detect `mtriple` via LLVM (#1211)
junrushao Nov 6, 2023
9869ca6
Fix Python3.8 compatibility breakage (#1210)
Lunderberg Nov 6, 2023
4042626
[Slim-LM] Enable loading from AWQ pre-quantized weight. (#1114)
LeshengJin Nov 6, 2023
be1c18b
[Bugfix] Fix Cannot import name '_LIB' from 'mlc_chat.base' (#1214)
CharlieFRuan Nov 7, 2023
1015aae
[SLM] Support `q3f16_1` and `q4f32_1` (#1215)
cyx-6 Nov 8, 2023
1a6fadd
Make the Compilation Working E2E (#1218)
junrushao Nov 8, 2023
616ca42
[Mistral][SWA] Add sliding window to metadata (#1217)
CharlieFRuan Nov 8, 2023
e52f449
Support for `chatml` format conversation (for TinyLlama-1.1B-Chat-v0.…
acalatrava Nov 8, 2023
fbe75e3
Add Rust Support for MLC-LLM (#1213)
YuchenJin Nov 8, 2023
beca2ab
[Bugfix] Remove dependency on openai_api in chat module (#1222)
CharlieFRuan Nov 8, 2023
9ee5705
Bake in RAM Usage in the Generated DSO (#1224)
junrushao Nov 8, 2023
069181c
[Fix] ChatModule python messages and offset types (#1220)
YuchenJin Nov 8, 2023
f1bc951
[Fix] Variable Upperbound Should be Injected before Build Pipeline (#…
junrushao Nov 8, 2023
834811f
[MultiGPU] Support pre-sharded model weights (#1096)
Lunderberg Nov 9, 2023
d41ad34
Merge remote-tracking branch 'mlc-ai/main' into merge-upstream-nov11
masahi Nov 9, 2023
b022dc2
fix
masahi Nov 9, 2023
Files changed
22 changes: 22 additions & 0 deletions cpp/conv_templates.cc
@@ -7,6 +7,27 @@ namespace mlc {
namespace llm {
namespace {

Conversation ChatML() {
Conversation conv;
conv.name = "chatml";
conv.roles = {"<|im_start|>user", "<|im_start|>assistant"};
conv.system =
("<|im_start|>system A conversation between a user and an LLM-based AI assistant. The "
"assistant gives helpful and honest answers.<|im_end|> ");
conv.messages = {};
conv.offset = 0;
conv.separator_style = SeparatorStyle::kSepRoleMsg;
conv.seps = {"<|im_end|>", "<|im_end|>"};
conv.role_msg_sep = "\n";
conv.role_empty_sep = "\n";
// TODO(mlc-team): add eos to mlc-chat-config
// and remove eos from stop token setting.
conv.stop_tokens = {2};
conv.stop_str = "<|im_end|>";
conv.add_bos = true;
return conv;
}

Conversation LlamaDefault() {
Conversation conv;
conv.name = "llama_default";
@@ -583,6 +604,7 @@ using ConvFactory = Conversation (*)();

Conversation Conversation::FromTemplate(const std::string& name) {
static std::unordered_map<std::string, ConvFactory> factory = {
{"chatml", ChatML},
{"llama_default", LlamaDefault},
{"llama-2", Llama2},
{"mistral_default", MistralDefault},
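For orientation (not part of the diff), a rough Python sketch of the prompt string the new `chatml` template assembles for a single user turn, based only on the fields set in `ChatML()` above; the actual concatenation lives in the C++ runtime, so whitespace details may differ:

```python
# Sketch only: mirrors the ChatML() fields above, not the runtime's code path.
system = ("<|im_start|>system A conversation between a user and an LLM-based AI "
          "assistant. The assistant gives helpful and honest answers.<|im_end|> ")
roles = ("<|im_start|>user", "<|im_start|>assistant")
sep = "<|im_end|>"       # seps[0] and seps[1] are identical here
role_msg_sep = "\n"      # separator placed between a role tag and its message

def chatml_prompt(user_msg: str) -> str:
    prompt = system
    prompt += roles[0] + role_msg_sep + user_msg + sep
    prompt += roles[1] + role_msg_sep  # generation continues until <|im_end|> (stop_str)
    return prompt

print(chatml_prompt("Hello, how are you?"))
```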
15 changes: 8 additions & 7 deletions cpp/llm_chat.cc
@@ -409,10 +409,16 @@ class LLMChat {
CHECK(!config.count("max_window_size"))
<< "Cannot specify both sliding_window and max_window_size.";
this->sliding_window_ = config["sliding_window"].get<int64_t>();
CHECK(this->sliding_window_ > 0) << "Sliding window size needs to be positive";
CHECK(config.count("sliding_window_chunk_size"))
<< "Need to specify chunk size if using sliding window attention.";
}
if (config.count("sliding_window_chunk_size")) {
CHECK(config["sliding_window_chunk_size"].is<int64_t>());
this->sliding_window_chunk_size_ = config["sliding_window_chunk_size"].get<int64_t>();
CHECK(this->sliding_window_chunk_size_ > 0)
<< "Sliding window chunk size needs to be positive";
CHECK(config.count("sliding_window")) << "Need to specify sliding window size.";
}
if (config.count("model_name")) {
CHECK(config["model_name"].is<std::string>());
@@ -828,13 +834,8 @@ class LLMChat {
NDArray logits_on_device;
if (this->sliding_window_ != -1) {
// Use chunking if we use sliding window attention (see Mistral paper figure 3).
int64_t sliding_window_chunk_size = this->sliding_window_chunk_size_;
if (this->sliding_window_chunk_size_ == -1) {
// One chunk if chunk size not specified
sliding_window_chunk_size = token_len;
}
for (int64_t begin = 0; begin < token_len; begin += sliding_window_chunk_size) {
int64_t end = std::min(token_len, begin + sliding_window_chunk_size);
for (int64_t begin = 0; begin < token_len; begin += this->sliding_window_chunk_size_) {
int64_t end = std::min(token_len, begin + this->sliding_window_chunk_size_);
std::vector<int32_t> chunk =
std::vector<int32_t>(prompt_tokens.begin() + begin, prompt_tokens.begin() + end);
new_seq_len += static_cast<int64_t>(chunk.size());
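With the config checks added above, `sliding_window_chunk_size_` is guaranteed to be set whenever sliding window attention is used, so the prefill loop no longer needs a fallback chunk size. A minimal, runtime-independent Python sketch of the chunking behaviour:

```python
from typing import Iterator, List

def chunked(prompt_tokens: List[int], sliding_window_chunk_size: int) -> Iterator[List[int]]:
    """Yield consecutive chunks of at most `sliding_window_chunk_size` tokens,
    mirroring the C++ prefill loop above."""
    token_len = len(prompt_tokens)
    for begin in range(0, token_len, sliding_window_chunk_size):
        end = min(token_len, begin + sliding_window_chunk_size)
        yield prompt_tokens[begin:end]

# A 7-token prompt with chunk size 3 is prefilled as [0, 1, 2], [3, 4, 5], [6].
print(list(chunked(list(range(7)), 3)))
```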
2 changes: 1 addition & 1 deletion mlc_llm/core.py
@@ -849,7 +849,7 @@ def build_model_from_args(args: argparse.Namespace):

mod = mod_transform_before_build(mod, param_manager, args, model_config)
if args.num_shards > 1:
# We requires a "create_sharding_info" function for all
# We require a "create_sharding_info" function for all
# multi-GPU models, even if they are using pre-sharded
# weights. When using pre-sharded weights, the list of
# initialization-time transforms to apply is empty.
5 changes: 5 additions & 0 deletions mlc_llm/relax_model/mistral.py
@@ -949,6 +949,9 @@ def get_model(args, hf_config):
sliding_window_chunk_size=args.sliding_window_chunk_size,
)

assert config.sliding_window != -1
assert config.sliding_window_chunk_size != -1

param_manager = ParamManager()
bb = relax.BlockBuilder()

@@ -962,6 +965,8 @@
max_window_size=config.max_sequence_length,
stop_tokens=[2],
add_prefix_space=False,
sliding_window=config.sliding_window,
sliding_window_chunk_size=config.sliding_window_chunk_size,
)

mod = bb.get()
50 changes: 29 additions & 21 deletions python/mlc_chat/chat_module.py
@@ -8,13 +8,15 @@
import warnings
from dataclasses import asdict, dataclass, fields
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import tvm
from tvm.runtime import disco # pylint: disable=unused-import

from .base import _LIB # pylint: disable=unused-import
from .interface.openai_api import ChatMessage
from . import base # pylint: disable=unused-import

if TYPE_CHECKING:
from .interface.openai_api import ChatMessage

# pylint: disable=line-too-long
_PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb"
@@ -41,10 +43,10 @@ class ConvConfig:  # pylint: disable=too-many-instance-attributes
roles : Optional[List[str]]
An array that describes the role names of the user and the model. These
names are specific to the model being used.
messages : Optional[List[str]]
messages : Optional[List[List[str]]]
The chat history represented as an array of string pairs in the following
format: ``[[role_0, msg_0], [role_1, msg_1], ...]``.
offset : Optional[str]
offset : Optional[int]
The offset used to begin the chat from the chat history. When offset
is not ``0``, ``messages[0:offset-1]`` will be encoded.
separator_style : Optional[int]
@@ -69,7 +71,7 @@ class ConvConfig:  # pylint: disable=too-many-instance-attributes
system: Optional[str] = None
roles: Optional[List[str]] = None
messages: Optional[List[List[str]]] = None
offset: Optional[str] = None
offset: Optional[int] = None
separator_style: Optional[int] = None
seps: Optional[List[str]] = None
role_msg_sep: Optional[str] = None
@@ -787,7 +789,7 @@ def __init__(

def generate(
self,
prompt: Union[str, List[ChatMessage]],
prompt: Union[str, List["ChatMessage"]],
generation_config: Optional[GenerationConfig] = None,
progress_callback=None,
) -> Union[str, List[str]]:
@@ -797,14 +799,18 @@ def generate(

Parameters
----------
prompt : Union[str, List[ChatMessage]]
prompt: Union[str, List[ChatMessage]]
The user input prompt, i.e. a question to ask the chat module.
It can also be the whole conversation history (list of messages with role and content)
eg: ```[
ChatMessage(role="user", content="Hello, how are you?"),
ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"),
ChatMessage(role="user", content="I'm good too."),
]```
eg:

.. code::

[
ChatMessage(role="user", content="Hello, how are you?"),
ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"),
ChatMessage(role="user", content="I'm good too."),
]
generation_config: Optional[GenerationConfig]
The generation config object to override the ChatConfig generation settings.
progress_callback: object
@@ -841,8 +847,6 @@ def generate(
if (generation_config is not None) and (generation_config.n is not None):
num_return_sequences = generation_config.n
return_str = False
else:
num_return_sequences = 1

for _ in range(num_return_sequences):
self.reset_chat()
@@ -1001,7 +1005,7 @@ def _unload(self):

def _prefill(
self,
input: Union[str, List[ChatMessage]], # pylint: disable=redefined-builtin
input: Union[str, List["ChatMessage"]], # pylint: disable=redefined-builtin
decode_next_token: bool = True,
place_in_prompt: PlaceInPrompt = PlaceInPrompt.All,
generation_config: Optional[GenerationConfig] = None,
@@ -1014,11 +1018,15 @@ def _prefill(
input : Union[str, List[ChatMessage]]
The user input prompt, i.e. a question to ask the chat module.
It can also be the whole conversation history (list of messages with role and content)
eg: ```[
ChatMessage(role="user", content="Hello, how are you?"),
ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"),
ChatMessage(role="user", content="I'm good too."),
]```
eg:

.. code::

[
ChatMessage(role="user", content="Hello, how are you?"),
ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"),
ChatMessage(role="user", content="I'm good too."),
]
decode_next_token : bool
Whether to decode the next token after prefilling.
place_in_prompt: PlaceInPrompt
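A hedged usage sketch for the `Union[str, List[ChatMessage]]` prompt documented above; the model id is a placeholder and the example assumes prebuilt weights and a model library are already available locally:

```python
from mlc_chat import ChatModule
from mlc_chat.interface.openai_api import ChatMessage

cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")  # placeholder model id

# `generate` accepts either a plain string or a whole conversation history.
history = [
    ChatMessage(role="user", content="Hello, how are you?"),
    ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"),
    ChatMessage(role="user", content="I'm good too."),
]
print(cm.generate(prompt=history))
```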
62 changes: 35 additions & 27 deletions python/mlc_chat/compiler/compile.py
@@ -52,39 +52,46 @@ def _attach_auxiliary_methods(
mod: IRModule,
named_params: List[Tuple[str, nn.Parameter]],
args: CompileArgs,
model_config,
) -> None:
def _metadata():
metadata = {
"quantization": args.quantization.name,
"model_type": args.model.name,
"params": [
{
"name": name,
"shape": list(param.shape),
"dtype": param.dtype,
}
for name, param in named_params
],
}
def _get_memory_usage():
return {str(k): int(v) for k, v in mod.attrs["mlc_llm.memory_usage"].items()}

def _get_param_info():
return [
{
"name": name,
"shape": list(param.shape),
"dtype": param.dtype,
}
for name, param in named_params
]

def _emit_metadata(metadata):
bb = relax.BlockBuilder() # pylint: disable=invalid-name
with bb.function("main", params=[]):
bb.emit_func_output(relax.StringImm(json.dumps(metadata)))
return bb.get()["main"]

def _attach_variable_bounds():
for g_var, func in mod.functions_items():
if isinstance(func, relax.Function):
mod[g_var] = func.with_attr(
"tir_var_upper_bound",
{
"seq_len": model_config.max_sequence_length,
"total_seq_len": model_config.max_sequence_length,
},
)
mod["_metadata"] = _emit_metadata(
metadata={
"quantization": args.quantization.name,
"model_type": args.model.name,
"memory_usage": _get_memory_usage(),
"params": _get_param_info(),
}
)


mod["_metadata"] = _metadata()
_attach_variable_bounds()
def _attach_variable_bounds(mod, model_config):
for g_var, func in mod.functions_items():
if isinstance(func, relax.Function):
mod[g_var] = func.with_attr(
"tir_var_upper_bound",
{
"seq_len": model_config.max_sequence_length,
"total_seq_len": model_config.max_sequence_length,
},
)


def _compile(args: CompileArgs):
@@ -96,10 +103,11 @@ def _compile(args: CompileArgs):
mod, named_params = model.export_tvm(
spec=model.get_default_spec(), # type: ignore
)
_attach_auxiliary_methods(mod, named_params, args, model_config)
logger.info("Running optimizations using TVM Unity")
_attach_variable_bounds(mod, model_config)
with args.target:
mod = relax.get_pipeline("mlc_llm")(mod)
_attach_auxiliary_methods(mod, named_params, args)
logger.info("Generating code using TVM Unity")
args.build_func(mod, args)
logger.info("Generated: %s", bold(str(args.output)))
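With `EstimateMemoryUsage` in the pipeline (see the new pass below), the generated `_metadata` function now also reports per-function memory usage. An illustrative shape of the emitted JSON, with invented values:

```python
import json

# Values are made up for illustration; "memory_usage" maps each Relax function
# name to its statically planned allocation size in bytes.
example_metadata = {
    "quantization": "q4f16_1",
    "model_type": "llama",
    "memory_usage": {"prefill": 269_000_000, "decode": 12_000_000},
    "params": [
        {"name": "model.embed_tokens.weight", "shape": [32000, 4096], "dtype": "uint32"},
    ],
}
print(json.dumps(example_metadata, indent=2))
```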
77 changes: 77 additions & 0 deletions python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py
@@ -0,0 +1,77 @@
"""Memory usage estimation analysis function for Relax functions."""
from typing import Dict

import tvm
from tvm import relax
from tvm.ir import IRModule, Op
from tvm.relax.expr_functor import PyExprVisitor, visitor


@tvm.transform.module_pass(opt_level=0, name="EstimateMemoryUsage")
class EstimateMemoryUsage: # pylint: disable=too-few-public-methods
"""A pass that attaches the memory usage information as an IRModule attribute.

This pass relies on static analysis on each TVM Relax function in the specific IRModule.
It simply accumulates all memory allocation calls in a function, and does not consider
more dynamic runtime features like control flow ("if") or function calls.
"""

def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""Entry point of the pass."""
lowered_mod = tvm.transform.Sequential(
[
relax.transform.RewriteDataflowReshape(),
relax.transform.ToNonDataflow(),
relax.transform.RemovePurityChecking(),
relax.transform.CallTIRRewrite(),
relax.transform.StaticPlanBlockMemory(),
],
name="relax.lower",
)(mod)
usage = _MemoryEstimator().run(lowered_mod)
return mod.with_attr("mlc_llm.memory_usage", usage)


@visitor
class _MemoryEstimator(PyExprVisitor):
"""The IR visitor which estimates the memory usage of each Relax function."""

def __init__(self) -> None:
self.planned_alloc_mem = 0
self.planned_mem_num = 0
self._op_alloc_tensor = Op.get("relax.builtin.alloc_tensor")
self._op_alloc_storage = Op.get("relax.memory.alloc_storage")

def run(self, mod: IRModule) -> Dict[str, int]:
"""Entry point of the visitor."""
result: Dict[str, int] = {}
for global_var, func in mod.functions_items():
if isinstance(func, relax.Function):
self.planned_alloc_mem = 0
self.planned_mem_num = 0
self.visit_expr(func)
result[global_var.name_hint] = self.planned_alloc_mem
return result

def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed
if call.op == self._op_alloc_tensor:
self._builtin_tensor_alloc(shape=call.args[0], dtype_str=call.args[1].value)
elif call.op == self._op_alloc_storage:
self._storage_alloc(size=call.args[0])
super().visit_call_(call)

def _builtin_tensor_alloc(self, shape: relax.Expr, dtype_str: str) -> None:
assert isinstance(shape, relax.ShapeExpr)
size = 1
for dim_len in shape.values:
if not isinstance(dim_len, tvm.tir.IntImm):
return
size *= dim_len.value
dtype = tvm.DataType(dtype_str)
self.planned_mem_num += 1
self.planned_alloc_mem += size * ((dtype.bits + 7) // 8) * dtype.lanes

def _storage_alloc(self, size: relax.Expr) -> None:
assert isinstance(size, relax.ShapeExpr)
self.planned_mem_num += 1
self.planned_alloc_mem += size.values[0].value
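A quick worked example of the allocation-size formula used in `_builtin_tensor_alloc`:

```python
# A [1, 4096] float16 tensor: 4096 elements * ((16 + 7) // 8) bytes * 1 lane = 8192 bytes.
bits, lanes = 16, 1
num_elements = 1 * 4096
print(num_elements * ((bits + 7) // 8) * lanes)  # 8192
```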
2 changes: 2 additions & 0 deletions python/mlc_chat/compiler/compiler_pass/pipeline.py
@@ -7,6 +7,7 @@
from tvm.relax import register_pipeline # pylint: disable=no-name-in-module

from .clean_up_tir_attrs import CleanUpTIRAttrs
from .estimate_memory_usage import EstimateMemoryUsage
from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise
from .fuse_dequantize_take import FuseDequantizeTake
from .fuse_dequantize_transpose import FuseDequantizeTranspose
@@ -64,6 +65,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
_LogProgress("Running memory optimizations"),
LiftTIRGlobalBufferAlloc(),
tvm.tir.transform.ForceNarrowIndexToInt32(),
EstimateMemoryUsage(),
]
)
mod = seq(mod._move()) # pylint: disable=protected-access