[Model Support][SWA] Add support for sliding window attention for Mistral (mlc-ai#1087)

* mistral base

* Add sliding window mask making and its tests

* Small changes for sliding window mask

* Clean up mask making

* Remove kv_seq_len

* Add prefill chunking, handle max window size in SWA

* Add interleave kv

* Temporary fix for kv seq len

* Pass in more shapes to SWA prefill and decode in runtime

* mistral var fix

* Small changes regarding shape passing

* Small fix on chunk size

* Add build args, fix mlc chat config dump

* mistral system prompt
---------

Co-authored-by: David Pissarra <[email protected]>
3 people authored Nov 3, 2023
1 parent 2dc8183 commit 6ae02dd
Showing 5 changed files with 1,542 additions and 27 deletions.
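The diffs below add a Mistral conversation template, sliding-window-aware prefill and decode in the C++ runtime, and new build arguments. As background for the mask-making commits above: with a sliding window of size W, each token attends only to itself and the W - 1 tokens before it instead of the full causal prefix. A minimal NumPy sketch of such a mask, for illustration only (the function and argument names are made up; this is not the TVM code added by the commit):

import numpy as np

def sliding_window_mask(num_queries, num_keys, window):
    """Additive mask for the last `num_queries` positions of a `num_keys`-token
    sequence: 0.0 where attention is allowed, -inf where it is blocked."""
    first_query_pos = num_keys - num_queries
    mask = np.full((num_queries, num_keys), -np.inf, dtype=np.float32)
    for i in range(num_queries):
        q = first_query_pos + i
        lo = max(0, q - window + 1)  # causal, and at most `window` tokens back
        mask[i, lo:q + 1] = 0.0
    return mask

# Example: a 4-token chunk attending over an 8-token context with window 3.
print(sliding_window_mask(4, 8, 3))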
4 changes: 4 additions & 0 deletions cpp/conv_templates.cc
@@ -49,6 +49,10 @@ Conversation Llama2() {
Conversation MistralDefault() {
Conversation conv;
conv.name = "mistral_default";
conv.system =
("[INST] Always assist with care, respect, and truth. Respond with utmost utility yet "
"securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies "
"promote fairness and positivity.");
conv.roles = {"[INST]", "[/INST]"};
conv.messages = {};
conv.offset = 0;
89 changes: 77 additions & 12 deletions cpp/llm_chat.cc
Expand Up @@ -295,6 +295,9 @@ class LLMChat {
if (ft_.use_disco) {
return false;
}
if (this->sliding_window_ != -1) {
return false;
}
PackedFunc fget_metadata = ft_.mod_get_func("get_metadata");
if (fget_metadata == nullptr) {
return false;
@@ -369,6 +372,16 @@ class LLMChat {
this->max_window_size_ =
std::min(this->max_window_size_, config["max_window_size"].get<int64_t>());
}
if (config.count("sliding_window")) {
CHECK(config["sliding_window"].is<int64_t>());
CHECK(!config.count("max_window_size"))
<< "Cannot specify both sliding_window and max_window_size.";
this->sliding_window_ = config["sliding_window"].get<int64_t>();
}
if (config.count("sliding_window_chunk_size")) {
CHECK(config["sliding_window_chunk_size"].is<int64_t>());
this->sliding_window_chunk_size_ = config["sliding_window_chunk_size"].get<int64_t>();
}
if (config.count("model_name")) {
CHECK(config["model_name"].is<std::string>());
this->model_name_ = config["model_name"].get<std::string>();
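For reference, the loader above reads these keys from mlc-chat-config.json; a hypothetical entry for an SWA model might look as follows (the numeric values are illustrative assumptions, and max_window_size must be absent because the CHECK above rejects the combination):

import json

swa_chat_config = {
    "model_category": "mistral",
    "sliding_window": 4096,             # window size read by llm_chat.cc
    "sliding_window_chunk_size": 4096,  # -1 would mean "prefill in one chunk"
    # no "max_window_size" key: mixing it with sliding_window is an error
}
print(json.dumps(swa_chat_config, indent=2))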
@@ -462,9 +475,11 @@ class LLMChat {
// so there is no explicit abi dependency on these extra
// classes other than basic tvm runtime.
this->ft_.Init(reload_lib, device_, this->num_shards_);
UpdateMaxWindowSizeFromMetadata();
CHECK(max_window_size_ != std::numeric_limits<int64_t>::max())
<< "Key \"max_window_size\" not found.";
if (this->sliding_window_ == -1) {
UpdateMaxWindowSizeFromMetadata();
CHECK(max_window_size_ != std::numeric_limits<int64_t>::max())
<< "Key \"max_window_size\" not found.";
}
// Step 4. Initialize sample functions.
auto fsample_topp_from_prob_ptr =
tvm::runtime::Registry::Get("vm.builtin.sample_top_p_from_prob");
@@ -562,7 +577,8 @@ class LLMChat {
std::string all_prompt = GetConcatPrompt(prompts, 0, 0);
std::vector<int32_t> encoded = this->tokenizer_->Encode(all_prompt);
tokens.insert(tokens.end(), encoded.begin(), encoded.end());
if (this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) {
if (this->sliding_window_ != -1 || // There is no max window size if we use sliding window
this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) {
return tokens;
}
// need shift window and re-encode
@@ -753,6 +769,10 @@ class LLMChat {
if (ft_.use_disco) {
LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model";
}
if (this->sliding_window_ != -1) {
LOG(FATAL)
<< "NotImplementedError: Sliding window attention does not support separate embedding";
}
NDArray embedding = Downcast<NDArray>(
EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str));
PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str);
@@ -772,8 +792,28 @@
}
auto tstart = std::chrono::high_resolution_clock::now();

int32_t new_seq_len = total_seq_len_ + token_len;
NDArray logits_on_device = this->ForwardTokens(prompt_tokens, new_seq_len);
int32_t new_seq_len = total_seq_len_;
NDArray logits_on_device;
if (this->sliding_window_ != -1) {
// Use chunking if we use sliding window attention (see Mistral paper figure 3).
int64_t sliding_window_chunk_size = this->sliding_window_chunk_size_;
if (this->sliding_window_chunk_size_ == -1) {
// One chunk if chunk size not specified
sliding_window_chunk_size = token_len;
}
for (int64_t begin = 0; begin < token_len; begin += sliding_window_chunk_size) {
int64_t end = std::min(token_len, begin + sliding_window_chunk_size);
std::vector<int32_t> chunk =
std::vector<int32_t>(prompt_tokens.begin() + begin, prompt_tokens.begin() + end);
new_seq_len += static_cast<int64_t>(chunk.size());
logits_on_device = this->ForwardTokens(chunk, new_seq_len);
}
ICHECK_EQ(new_seq_len, total_seq_len_ + token_len) << "Expect chunking process all tokens";
} else {
// Otherwise, prefill entire prompt at once.
new_seq_len += token_len;
logits_on_device = this->ForwardTokens(prompt_tokens, new_seq_len);
}
total_seq_len_ = new_seq_len;

if (!decode_next_token) {
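The chunked prefill above can be restated as a short Python sketch (a hypothetical helper, not the C++ runtime itself): the prompt is fed to the model sliding_window_chunk_size tokens at a time, so each forward call sees one chunk of new queries plus the rolling KV cache rather than the whole prompt at once.

def chunked_prefill(prompt_tokens, total_seq_len, chunk_size, forward_tokens):
    """forward_tokens(chunk, new_seq_len) stands in for ForwardTokens()."""
    if chunk_size == -1:
        chunk_size = len(prompt_tokens)  # no chunk size given: one chunk
    new_seq_len, logits = total_seq_len, None
    for begin in range(0, len(prompt_tokens), chunk_size):
        chunk = prompt_tokens[begin:begin + chunk_size]
        new_seq_len += len(chunk)
        logits = forward_tokens(chunk, new_seq_len)
    assert new_seq_len == total_seq_len + len(prompt_tokens)
    return logits, new_seq_len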
@@ -1111,7 +1151,9 @@ class LLMChat {
}
// max_window_size_ != -1 to handle
// https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/relax_model/rwkv.py#L588-L589
else if (max_window_size_ != -1 && total_seq_len_ >= max_window_size_) {
// sliding_window_ == -1 to make sure we do not stop when using sliding window
else if (max_window_size_ != -1 && sliding_window_ == -1 &&
total_seq_len_ >= max_window_size_) {
stop_triggered_ = true;
}
if (stop_triggered_) {
@@ -1125,7 +1167,18 @@
if (input_tokens.size() > 1 && ft_.prefill_func_.defined()) {
ObjectRef input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray(input_tokens));
ShapeTuple cur_pos_shape = ShapeTuple({cur_pos});
ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_);
if (sliding_window_ == -1) {
ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_);
} else {
// Sliding window attention needs extra shape parameters
int64_t seq_len = static_cast<int64_t>(input_tokens.size());
// Number of elements in the cache
int64_t cache_len = std::min(this->sliding_window_, cur_pos - seq_len);
ShapeTuple cache_len_shape = ShapeTuple({cache_len});
ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len});
ret = ft_.prefill_func_(input_data, cur_pos_shape, cache_len_shape, kv_seq_len_shape,
kv_cache_, params_);
}
} else {
// running decode function when prefill is not available
for (int i = 0; i < input_tokens.size(); ++i) {
@@ -1138,8 +1191,19 @@
input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray({input_tokens[i]}));
}
int64_t pos = cur_pos + i + 1 - input_tokens.size();
ShapeTuple pos_shape = ShapeTuple({cur_pos});
ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_);
ShapeTuple pos_shape = ShapeTuple({pos});
if (sliding_window_ == -1) {
ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_);
} else {
// Sliding window attention needs extra shape parameters
int64_t seq_len = static_cast<int64_t>(input_tokens.size());
// Number of elements in the cache
int64_t cache_len = std::min(this->sliding_window_, pos - seq_len);
ShapeTuple cache_len_shape = ShapeTuple({cache_len});
ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len});
ret = ft_.decode_func_(input_data, pos_shape, cache_len_shape, kv_seq_len_shape,
kv_cache_, params_);
}
}
}
if (ft_.use_disco) {
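In both branches the extra shapes are plain bookkeeping; as a sketch of the arithmetic (using the prefill branch's names, where cur_pos already counts the tokens being fed in):

def swa_extra_shapes(cur_pos, num_new_tokens, sliding_window):
    seq_len = num_new_tokens
    # Valid entries already in the rolling cache, capped at the window size
    # because older entries have been overwritten.
    cache_len = min(sliding_window, cur_pos - seq_len)
    # Total number of keys/values this forward call attends over.
    kv_seq_len = cache_len + seq_len
    return cache_len, kv_seq_len

# Example: decoding one token at position 5000 with a 4096-token window.
print(swa_extra_shapes(cur_pos=5000, num_new_tokens=1, sliding_window=4096))
# -> (4096, 4097)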
@@ -1265,9 +1329,10 @@ class LLMChat {
Conversation conversation_;
// total sequence len,
int64_t total_seq_len_{0};
// max window size, mean generation length
// max window size, mean and max generation length, sliding window
// If we use sliding window, max window size is its default max() value
int64_t max_window_size_{std::numeric_limits<int64_t>::max()}, mean_gen_len_{128},
max_gen_len_{512};
max_gen_len_{512}, sliding_window_{-1}, sliding_window_chunk_size_{-1};
// size of the vocab table
int64_t vocab_size_;
// number of shards in distributed inference
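The "Add interleave kv" commit refers to the rolling-buffer KV cache that SWA makes possible: once the window is full, new keys and values overwrite the oldest slots modulo the window size, so cache memory stays bounded at sliding_window entries per layer. A toy sketch of the idea (not the TVM cache operations used by the model definition):

import numpy as np

class RollingKVCache:
    """Fixed-size per-layer cache: global position p lives in slot p % window."""

    def __init__(self, window, num_kv_heads, head_dim):
        self.window = window
        self.keys = np.zeros((window, num_kv_heads, head_dim), dtype=np.float32)
        self.values = np.zeros_like(self.keys)

    def append(self, new_keys, new_values, start_pos):
        # start_pos is the global position of the first new token.
        for i in range(new_keys.shape[0]):
            slot = (start_pos + i) % self.window  # overwrite the oldest entry
            self.keys[slot] = new_keys[i]
            self.values[slot] = new_values[i]

    def num_valid(self, cur_pos):
        # Matches cache_len above: tokens seen so far, capped at the window.
        return min(self.window, cur_pos)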
75 changes: 60 additions & 15 deletions mlc_llm/core.py
Expand Up @@ -24,6 +24,7 @@
llama,
llama_batched_vllm,
minigpt,
mistral,
param_manager,
rwkv,
stablelm_3b,
@@ -80,6 +81,13 @@ class BuildArgs:
Build with separated embedding layer, only applicable to LlaMa. This
feature is in testing stage, and will be formally replaced after massive
overhaul of embedding feature for all models and use cases.
sliding_window: int
The sliding window size in sliding window attention (SWA). This optional field
overrides the `sliding_window` in config.json for those models that use SWA.
Currently only useful when compiling Mistral.
sliding_window_chunk_size: int
The chunk size in sliding window attention (SWA) during prefilling. By default,
the chunk size is the same as sliding window. Currently only useful when compiling Mistral.
cc_path: str
``/path/to/cross_compiler_path``; currently only used for cross-compile
for nvidia/jetson device.
@@ -184,7 +192,10 @@ class BuildArgs:
cc_path: str = field(
default="",
metadata={
"help": "/path/to/cross_compiler_path, Currently only used for cross-compile for nvidia/jetson device."
"help": (
"/path/to/cross_compiler_path, Currently only used for "
"cross-compile for nvidia/jetson device."
)
},
)
system_lib: bool = field(
@@ -275,6 +286,26 @@ class BuildArgs:
"action": "store_true",
},
)
sliding_window: int = field(
default=-1,
metadata={
"help": (
"The sliding window size in sliding window attention (SWA). "
"This optional field overrides the `sliding_window` in config.json for "
"those models that use SWA. Currently only useful when compiling Mistral."
),
},
)
sliding_window_chunk_size: int = field(
default=-1,
metadata={
"help": (
"The chunk size in sliding window attention (SWA) during prefilling. "
"By default, the chunk size is the same as sliding window. "
"Currently only useful when compiling Mistral."
),
},
)
pdb: bool = field(
default=False,
metadata={
@@ -286,7 +317,8 @@ class BuildArgs:
default=False,
metadata={
"help": (
"Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True."
"Use vLLM paged KV cache and attention kernel, only relevant when "
"enable_batching=True."
),
"action": "store_true",
},
@@ -330,7 +362,9 @@ def _parse_args(parsed) -> argparse.Namespace:
if parsed.use_vllm_attention:
assert parsed.enable_batching, "--enable_batching is required for using vLLM attention."
assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA."
assert tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True), "TVM needs to be built with -DUSE_VLLM=ON."
assert tvm.get_global_func(
"tvm.contrib.vllm.single_query_cached_kv_attention", True
), "TVM needs to be built with -DUSE_VLLM=ON."

parsed.artifact_path = os.path.join(
parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
@@ -391,10 +425,10 @@ def _setup_model_path(args: argparse.Namespace):  # pylint: disable=too-many-bra
def validate_config(model_path: str):
if os.path.exists(os.path.join(model_path, "mlc-chat-config.json")):
raise KeyError(
"The model located in the directory {} has already been compiled by MLC-LLM. There is"
" no need to compile it again. If you wish to compile a new model, please provide a"
" directory (or hf-path) that contains the pre-compiled model in raw HuggingFace"
" format instead.".format(model_path)
f"The model located in the directory {model_path} has already been compiled "
"by MLC-LLM. There is no need to compile it again. If you wish to compile "
"a new model, please provide a directory (or hf-path) that contains the "
"pre-compiled model in raw HuggingFace format instead."
)
if model_path.split("/")[-1].startswith("minigpt"):
# minigpt does not contain a config.json file so we skip the check
@@ -467,19 +501,21 @@ def mod_transform_before_build(

if max_seq_len:
num_key_value_heads = config.get_num_key_value_heads()
# pylint: disable=no-value-for-parameter
mod = fuse_split_rotary_embedding(
config.num_attention_heads // args.num_shards,
num_key_value_heads // args.num_shards,
config.hidden_size // args.num_shards,
config.position_embedding_base,
)(mod)
config.num_attention_heads // args.num_shards,
num_key_value_heads // args.num_shards,
config.hidden_size // args.num_shards,
config.position_embedding_base,
)(mod)

if args.target_kind == "cuda":
patterns = []

has_cutlass = tvm.get_global_func("relax.ext.cutlass", True)

if has_cutlass and not args.no_cutlass_attn:
# pylint: disable=no-value-for-parameter
if args.use_flash_attn_mqa:
mod = rewrite_attention(use_flash_mqa=True)(mod)
mod = rewrite_attention(use_flash_mqa=False)(mod)
@@ -565,7 +601,6 @@ def dump_mlc_chat_config(
config["top_p"] = top_p
config["mean_gen_len"] = mean_gen_len
config["max_gen_len"] = max_gen_len
config["max_window_size"] = max_window_size
config["num_shards"] = args.num_shards
config["shift_fill_factor"] = shift_fill_factor
if rwkv_world:
Expand All @@ -575,6 +610,12 @@ def dump_mlc_chat_config(
config["model_category"] = args.model_category
config["model_name"] = args.model
config["vocab_size"] = vocab_size
if args.sliding_window != -1:
# Do not add max window size if use sliding window
config["sliding_window"] = args.sliding_window
config["sliding_window_chunk_size"] = args.sliding_window_chunk_size
else:
config["max_window_size"] = max_window_size

args.chat_config_path = os.path.join(args.params_path, "mlc-chat-config.json")
with open(args.chat_config_path, "w", encoding="utf-8") as outfile:
@@ -640,7 +681,7 @@ def build_model_from_args(args: argparse.Namespace):
if args.quantization == "q4f16_0":
print(
"WARNING: q4f16_1 is preferred to q4f16_0, "
"and it is highly recommended to use q4f16_1 instaed"
"and it is highly recommended to use q4f16_1 instead"
)
if args.num_shards > 1:
if (not args.build_model_only) and (not args.convert_weight_only):
@@ -670,7 +711,7 @@ def build_model_from_args(args: argparse.Namespace):
if not use_cache or args.convert_weight_only:
model_generators = {
"llama": llama,
"mistral": llama,
"mistral": mistral,
"stablelm_epoch": stablelm_3b,
"gpt_neox": gpt_neox,
"gpt_bigcode": gpt_bigcode,
@@ -691,6 +732,10 @@
args, config
)

if args.model_category == "mistral":
args.sliding_window = model_config.sliding_window
args.sliding_window_chunk_size = model_config.sliding_window_chunk_size

for qspec_updater_class in param_manager.qspec_updater_classes:
qspec_updater = qspec_updater_class(param_manager)
qspec_updater.visit_module(mod)
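Putting the core.py changes together: for Mistral, the model's own configuration drives the SWA build arguments, and the dumped chat config then carries the sliding-window keys instead of max_window_size. A condensed paraphrase (plain functions standing in for the code above; model_config is whatever the Mistral relax model returns):

def apply_mistral_swa_args(args, model_config):
    # Mirrors build_model_from_args: Mistral's config wins over CLI defaults.
    if args.model_category == "mistral":
        args.sliding_window = model_config.sliding_window
        args.sliding_window_chunk_size = model_config.sliding_window_chunk_size

def window_fields_for_chat_config(args, max_window_size):
    # Mirrors dump_mlc_chat_config: SWA models drop max_window_size.
    if args.sliding_window != -1:
        return {
            "sliding_window": args.sliding_window,
            "sliding_window_chunk_size": args.sliding_window_chunk_size,
        }
    return {"max_window_size": max_window_size}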