[SLM] Batched Llama
This PR introduces batched Llama modeling with a paged KV cache in the SLM flow.
MasterJH5574 committed Jan 4, 2024
1 parent 073e007 commit 319ce79
Showing 9 changed files with 544 additions and 39 deletions.
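
For context on what the new batched model relies on: a paged KV cache stores each sequence's key/value history in fixed-size pages drawn from a shared pool, which is what lets many requests be prefilled and decoded in one batch. The sketch below only illustrates that bookkeeping with made-up names (ToyPagedKVCache); it is not the TVM runtime paged KV cache this PR plugs into.

# Illustrative sketch of paged KV-cache bookkeeping (hypothetical names,
# not the TVM runtime KV cache used by this PR).
from typing import Dict, List, Tuple


class ToyPagedKVCache:
    """Maps every (sequence, position) to a (page, offset) slot in a shared pool."""

    def __init__(self, num_pages: int, page_size: int) -> None:
        self.page_size = page_size
        self.free_pages: List[int] = list(range(num_pages))
        self.page_table: Dict[int, List[int]] = {}   # seq_id -> list of page ids
        self.seq_len: Dict[int, int] = {}            # seq_id -> tokens stored so far

    def add_sequence(self, seq_id: int) -> None:
        self.page_table[seq_id] = []
        self.seq_len[seq_id] = 0

    def append(self, seq_id: int) -> Tuple[int, int]:
        """Reserve one KV slot for the next token of `seq_id`."""
        pos = self.seq_len[seq_id]
        if pos % self.page_size == 0:                # current page is full (or none yet)
            self.page_table[seq_id].append(self.free_pages.pop())
        self.seq_len[seq_id] = pos + 1
        page = self.page_table[seq_id][pos // self.page_size]
        return page, pos % self.page_size


cache = ToyPagedKVCache(num_pages=8, page_size=16)
for seq in (0, 1):
    cache.add_sequence(seq)
print(cache.append(0), cache.append(1))  # each batched sequence gets its own pages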
cpp/llm_chat.cc (15 changes: 12 additions & 3 deletions)
@@ -576,7 +576,7 @@ class LLMChat {
     // Step 6. KV cache creation.
     this->kv_cache_ = ft_.create_kv_cache_func_();
     // Step 7. Pre-allocate fixed size ndarray
-    this->temperature_arr_ = NDArray::Empty({}, DataType::Float(32), device_);
+    this->temperature_arr_ = NDArray::Empty({1}, DataType::Float(32), device_);
     float temperature = static_cast<float>(this->temperature_);
     this->temperature_arr_.CopyFromBytes(&temperature, sizeof(float));
     if (ft_.use_disco) {
@@ -1081,7 +1081,7 @@ class LLMChat {
       CHECK(generation_config["temperature"].is<double>());
       *gen_temperature = generation_config["temperature"].get<double>();
 
-      *gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_);
+      *gen_temperature_arr = NDArray::Empty({1}, DataType::Float(32), device_);
       float temperature_cast = static_cast<float>(*gen_temperature);
       gen_temperature_arr->CopyFromBytes(&temperature_cast, sizeof(float));
     } else {
@@ -1331,7 +1331,16 @@ class LLMChat {
 
   NDArray Softmax(NDArray input, NDArray temperature_arr) {
     NDArray ret;
-    ret = ft_.softmax_func_(input, temperature_arr);
+    try {
+      ret = ft_.softmax_func_(input, temperature_arr);
+    } catch (const dmlc::Error& e) {
+      // This branch is for compatibility:
+      // The old softmax function takes temperature arr with shape (),
+      // and the new softmax func takes temperature arr with shape (1,).
+      // Remove this branch after updating all prebuilt model libraries.
+      temperature_arr = temperature_arr.CreateView({}, temperature_arr->dtype);
+      ret = ft_.softmax_func_(input, temperature_arr);
+    }
     return ret;
   }

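The fallback above exists because older prebuilt model libraries were compiled with a softmax function expecting a scalar (shape ()) temperature, while the updated flow passes a shape-(1,) array; the catch branch reshapes the argument to a scalar view and retries. As a rough illustration of the computation being dispatched, here is a NumPy sketch of temperature-scaled softmax that accepts either temperature shape (illustrative only, not the compiled TVM kernel):

import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: np.ndarray) -> np.ndarray:
    """Temperature-scaled softmax; accepts a temperature of shape () or (1,)."""
    t = float(temperature.reshape(()))             # shape (1,) collapses to a scalar view
    scaled = logits / t
    scaled -= scaled.max(axis=-1, keepdims=True)   # subtract max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum(axis=-1, keepdims=True)


logits = np.array([2.0, 1.0, 0.1], dtype="float32")
print(softmax_with_temperature(logits, np.array([0.7], dtype="float32")))  # new (1,) shape
print(softmax_with_temperature(logits, np.array(0.7, dtype="float32")))    # old () shape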
python/mlc_chat/cli/gen_config.py (7 changes: 7 additions & 0 deletions)
@@ -76,6 +76,12 @@ def _parse_output(path: Union[str, Path]) -> Path:
         default=None,
         help=HELP["tensor_parallel_shards"] + ' (default: "%(default)s")',
     )
+    parser.add_argument(
+        "--max-batch-size",
+        type=int,
+        default=80,
+        help=HELP["max_batch_size"] + ' (default: "%(default)s")',
+    )
     parser.add_argument(
         "--output",
         "-o",
@@ -95,5 +101,6 @@ def _parse_output(path: Union[str, Path]) -> Path:
         prefill_chunk_size=parsed.prefill_chunk_size,
         attention_sink_size=parsed.attention_sink_size,
         tensor_parallel_shards=parsed.tensor_parallel_shards,
+        max_batch_size=parsed.max_batch_size,
         output=parsed.output,
     )
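
Taken together, the two hunks above register the new flag on the CLI parser and forward its value into config generation. A minimal, self-contained sketch of that argparse-to-callable flow, using a stub in place of the real gen_config and a one-entry stand-in for the HELP table:

import argparse

HELP = {"max_batch_size": "The maximum allowed batch size set for batch prefill/decode function."}


def gen_config_stub(**kwargs) -> None:
    # Stand-in for mlc_chat.interface.gen_config.gen_config; just echoes what it receives.
    print(kwargs)


def main(argv=None) -> None:
    parser = argparse.ArgumentParser("gen_config sketch")
    parser.add_argument(
        "--max-batch-size",
        type=int,
        default=80,
        help=HELP["max_batch_size"] + ' (default: "%(default)s")',
    )
    parsed = parser.parse_args(argv)
    gen_config_stub(max_batch_size=parsed.max_batch_size)


main(["--max-batch-size", "32"])  # prints {'max_batch_size': 32}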
python/mlc_chat/help.py (3 changes: 3 additions & 0 deletions)
@@ -101,6 +101,9 @@
     "attention_sink_size": """
 (Experimental) The number of stored sinks. Only supported on Mistral yet. By default,
 the number of sinks is 4. This flag subjects to future refactoring.
 """.strip(),
+    "max_batch_size": """
+The maximum allowed batch size set for batch prefill/decode function.
+""".strip(),
     """tensor_parallel_shards""": """
 Number of shards to split the model into in tensor parallelism multi-gpu inference.
python/mlc_chat/interface/compile.py (18 changes: 9 additions & 9 deletions)
@@ -85,16 +85,16 @@ def _apply_preproc_to_params(
 
 def _compile(args: CompileArgs, model_config: ConfigBase):
     def _get_variable_bounds(model_config) -> Dict[str, int]:
+        variable_bounds = {"seq_len": model_config.prefill_chunk_size}
         if hasattr(model_config, "sliding_window_size"):
-            return {
-                "seq_len": model_config.prefill_chunk_size,
-                "rolling_cache_len": model_config.sliding_window_size,
-                "kv_seq_len": model_config.sliding_window_size + model_config.prefill_chunk_size,
-            }
-        return {
-            "seq_len": model_config.prefill_chunk_size,
-            "total_seq_len": model_config.context_window_size,
-        }
+            variable_bounds["rolling_cache_len"] = model_config.sliding_window_size
+            variable_bounds["kv_seq_len"] = (
+                model_config.sliding_window_size + model_config.prefill_chunk_size
+            )
+        else:
+            variable_bounds["total_seq_len"] = model_config.context_window_size
+        variable_bounds["batch_size"] = getattr(model_config, "max_batch_size", 1)
+        return variable_bounds
 
     def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
         return {
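After this refactor, _get_variable_bounds always emits a batch_size bound (defaulting to 1 when the model config defines no max_batch_size), plus either sliding-window or full-context-window bounds. Below is a standalone sketch of the same logic exercised with two toy configs; the dataclasses are stand-ins, not the real mlc_chat model configs:

from dataclasses import dataclass
from typing import Dict


@dataclass
class ToyLlamaConfig:               # stand-in for a context-window model config
    prefill_chunk_size: int = 4096
    context_window_size: int = 4096
    max_batch_size: int = 80


@dataclass
class ToyMistralConfig:             # stand-in for a sliding-window model config
    prefill_chunk_size: int = 4096
    context_window_size: int = 32768
    sliding_window_size: int = 4096


def get_variable_bounds(model_config) -> Dict[str, int]:
    bounds = {"seq_len": model_config.prefill_chunk_size}
    if hasattr(model_config, "sliding_window_size"):
        bounds["rolling_cache_len"] = model_config.sliding_window_size
        bounds["kv_seq_len"] = model_config.sliding_window_size + model_config.prefill_chunk_size
    else:
        bounds["total_seq_len"] = model_config.context_window_size
    bounds["batch_size"] = getattr(model_config, "max_batch_size", 1)
    return bounds


print(get_variable_bounds(ToyLlamaConfig()))    # includes batch_size=80
print(get_variable_bounds(ToyMistralConfig()))  # no max_batch_size, so batch_size=1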
python/mlc_chat/interface/gen_config.py (3 changes: 3 additions & 0 deletions)
@@ -33,6 +33,7 @@ class MLCChatConfig: # pylint: disable=too-many-instance-attributes
     prefill_chunk_size: int
     attention_sink_size: int
     tensor_parallel_shards: int
+    max_batch_size: int
     # Control the behavior of the runtime
     mean_gen_len: int = None
     max_gen_len: int = None
@@ -79,6 +80,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
     prefill_chunk_size: Optional[int],
     attention_sink_size: Optional[int],
     tensor_parallel_shards: Optional[int],
+    max_batch_size: int,
     output: Path,
 ):
     """Entrypoint of MLC Chat configuration generation."""
@@ -100,6 +102,7 @@ def gen_config( # pylint: disable=too-many-locals,too-many-arguments,too-many-b
         prefill_chunk_size=model_config.prefill_chunk_size,
         attention_sink_size=getattr(model_config, "attention_sink_size", -1),
         tensor_parallel_shards=model_config.tensor_parallel_shards,
+        max_batch_size=max_batch_size,
         conv_template=conv_template,
     )
     # Step 2. Load `generation_config.json` and `config.json` for text-generation related configs
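
With max_batch_size added as a field of MLCChatConfig and a required gen_config argument, the chosen batch size ends up in the generated mlc-chat-config.json. A rough sketch of that serialization step, using a trimmed stand-in dataclass rather than the real MLCChatConfig:

import json
from dataclasses import asdict, dataclass


@dataclass
class ToyMLCChatConfig:              # trimmed stand-in for mlc_chat's MLCChatConfig
    context_window_size: int
    prefill_chunk_size: int
    tensor_parallel_shards: int
    max_batch_size: int
    conv_template: str


config = ToyMLCChatConfig(
    context_window_size=4096,
    prefill_chunk_size=4096,
    tensor_parallel_shards=1,
    max_batch_size=80,               # comes from --max-batch-size (default 80)
    conv_template="llama-2",
)
print(json.dumps(asdict(config), indent=2))  # roughly the shape written to mlc-chat-config.json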
