[MultiGPU] Support pre-sharded model weights #1096

Merged: 9 commits, Nov 9, 2023
20 changes: 17 additions & 3 deletions cpp/llm_chat.cc
@@ -175,19 +175,25 @@ struct FunctionTable {
}
}

ObjectRef LoadParams(const std::string& model_path, Device device) {
ObjectRef LoadParams(const std::string& model_path, Device device, bool use_presharded_weights) {
if (this->use_disco) {
std::filesystem::path fs_model_path = model_path;
std::string metadata_path = (fs_model_path / "ndarray-cache.json").string();
std::string ndarray_cache_metadata = LoadBytesFromFile(metadata_path);
PackedFunc loader_create = this->get_global_func("runtime.disco.ShardLoader");
PackedFunc loader_load_all = this->get_global_func("runtime.disco.ShardLoaderLoadAll");

auto load_all_func_name = use_presharded_weights
? "runtime.disco.ShardLoaderLoadAllPresharded"
: "runtime.disco.ShardLoaderLoadAll";
PackedFunc loader_load_all = this->get_global_func(load_all_func_name);
CHECK(loader_create != nullptr);
CHECK(loader_load_all != nullptr);
DRef loader = loader_create(metadata_path, ndarray_cache_metadata, "", this->disco_mod);
DRef params = loader_load_all(loader);
return params;
} else {
CHECK(!use_presharded_weights) << "Use of pre-sharded weights requires more than one GPU";

const PackedFunc* fload_cache = tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.load");
ICHECK(fload_cache) << "TVM runtime cannot find vm.builtin.ndarray_cache.load";
(*fload_cache)(model_path, static_cast<int32_t>(device.device_type), device.device_id);
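
As an aside, the loader selection above can also be checked from Python; a minimal sketch, assuming a disco-enabled TVM build that registers both ShardLoader variants as global functions (the variable names here are illustrative only):

import tvm

# Pick the loader variant the runtime will use; both names appear in the
# C++ change above. allow_missing=True returns None instead of raising
# when the function is not registered in this TVM build.
use_presharded_weights = True
func_name = (
    "runtime.disco.ShardLoaderLoadAllPresharded"
    if use_presharded_weights
    else "runtime.disco.ShardLoaderLoadAll"
)
loader_load_all = tvm.get_global_func(func_name, allow_missing=True)
print(func_name, "registered:", loader_load_all is not None)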
@@ -387,6 +393,12 @@ class LLMChat {
} else {
this->num_shards_ = 1;
}
if (config.count("use_presharded_weights")) {
CHECK(config["use_presharded_weights"].is<bool>());
this->use_presharded_weights_ = config["use_presharded_weights"].get<bool>();
} else {
this->use_presharded_weights_ = false;
}
if (config.count("max_window_size")) {
CHECK(config["max_window_size"].is<int64_t>());
this->max_window_size_ =
@@ -518,7 +530,7 @@ class LLMChat {
<< "Cannot find env function vm.builtin.sample_top_p_from_logits";
fsample_topp_from_logits_ = *fsample_topp_from_logits_ptr;
// Step 5. Load params in nd-array cache.
this->params_ = ft_.LoadParams(model_path, device_);
this->params_ = ft_.LoadParams(model_path, device_, use_presharded_weights_);
// Step 6. KV cache creation.
this->kv_cache_ = ft_.create_kv_cache_func_();
// Step 7. Pre-allocate fixed size ndarray
@@ -1358,6 +1370,8 @@ class LLMChat {
int64_t vocab_size_;
// number of shards in distributed inference
int64_t num_shards_;
// Load weights that were saved in sharded form
bool use_presharded_weights_;
// shift window fill factor
double shift_fill_factor_{0.3};
// temperature
69 changes: 62 additions & 7 deletions mlc_llm/core.py
@@ -1,5 +1,6 @@
# pylint: disable=missing-docstring, redefined-outer-name, not-callable
import argparse
import functools
import json
import os
import pickle
@@ -29,7 +30,8 @@
rwkv,
stablelm_3b,
)
from mlc_llm.relax_model.commons import create_shard_info_func
from mlc_llm.relax_model.commons import create_shard_info_func, create_shard_transformation_func
from mlc_llm.relax_model.param_manager import transform_params_for_each_rank, chain_parameter_transforms
from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention


@@ -279,6 +281,13 @@ class BuildArgs:
),
},
)
use_presharded_weights: bool = field(
default=False,
metadata={
"action": "store_true",
"help": "Produce separate weight sets for each shard.",
},
)
use_flash_attn_mqa: bool = field(
default=False,
metadata={
@@ -366,9 +375,14 @@ def _parse_args(parsed) -> argparse.Namespace:
"tvm.contrib.vllm.single_query_cached_kv_attention", True
), "TVM needs to be built with -DUSE_VLLM=ON."

parsed.artifact_path = os.path.join(
parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}"
)
model_name = [
parsed.model,
parsed.quantization.name,
]
if parsed.use_presharded_weights:
model_name.append(f"presharded-{parsed.num_shards}gpu")

parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name))

return parsed
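
For illustration of the naming logic above, with two shards and pre-sharding enabled the artifact directory gains a presharded-<N>gpu suffix; the model and quantization names below are placeholders:

# Hypothetical example: num_shards=2, use_presharded_weights=True.
model_name = ["Llama-2-7b-chat-hf", "q4f16_1", "presharded-2gpu"]
print("-".join(model_name))  # Llama-2-7b-chat-hf-q4f16_1-presharded-2gpu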

@@ -602,6 +616,7 @@ def dump_mlc_chat_config(
config["mean_gen_len"] = mean_gen_len
config["max_gen_len"] = max_gen_len
config["num_shards"] = args.num_shards
config["use_presharded_weights"] = args.use_presharded_weights
config["shift_fill_factor"] = shift_fill_factor
if rwkv_world:
config["tokenizer_files"] = ["tokenizer_model"]
@@ -741,12 +756,46 @@ def build_model_from_args(args: argparse.Namespace):
qspec_updater.visit_module(mod)

if not args.build_model_only:
parameter_transforms = []

# Run pre-quantization if provided.
args.model_path = param_manager.run_pre_quantize(args.model_path)
param_manager.init_torch_pname_to_bin_name(args.use_safetensors)
parameter_transforms.append(param_manager.create_parameter_transformation())

# Run pre-sharding if required
if args.num_shards > 1 and args.use_presharded_weights:
mod_shard = create_shard_transformation_func(param_manager, args, model_config)
mod_shard = transform_params_for_each_rank(mod_shard, num_shards=args.num_shards)
parameter_transforms.append(mod_shard)

# Chain all parameter transforms together. This allows
# ReorderTransformFunc to be applied to the single
# resulting parameter transformation function.
mod_transform = functools.reduce(chain_parameter_transforms, parameter_transforms)

seq = tvm.ir.transform.Sequential(
[
relax.transform.CanonicalizeBindings(),
relax.transform.EliminateCommonSubexpr(),
relax.transform.DeadCodeElimination(),
# TODO(Lunderberg): Implement
# relax.transform.Simplify() that applies
# canonicalization, CSE, and DCE until
# convergence.
relax.transform.CanonicalizeBindings(),
relax.transform.EliminateCommonSubexpr(),
relax.transform.DeadCodeElimination(),
param_manager.optimize_transform_param_order(),
],
name="SimplifyModTransform",
)

mod_transform = seq(mod_transform)
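
To illustrate the chaining step above, a minimal sketch of how functools.reduce folds a list of transforms into one; the real chain_parameter_transforms composes Relax IRModules rather than plain Python functions, so this is only an analogy:

import functools

def chain(first, second):
    # Analogous to chain_parameter_transforms: feed the output of the
    # first transform into the second.
    return lambda params: second(first(params))

transforms = [
    lambda p: p + ["quantized"],   # stands in for the quantization transform
    lambda p: p + ["sharded"],     # stands in for the per-rank sharding transform
]
combined = functools.reduce(chain, transforms)
print(combined([]))  # ['quantized', 'sharded']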

params = utils.convert_weights(mod_transform, param_manager, params, args)
utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1)

new_params = utils.convert_weights(param_manager, params, args)
utils.save_params(new_params, args.artifact_path)
if args.model_category != "minigpt":
utils.copy_tokenizer(args)
if args.model_category == "rwkv" or args.model_category == "rwkv_world":
@@ -772,7 +821,13 @@ def build_model_from_args(args: argparse.Namespace):

mod = mod_transform_before_build(mod, param_manager, args, model_config)
if args.num_shards > 1:
create_shard_info_func(mod, param_manager, args, model_config)
# We require a "create_sharding_info" function for all
# multi-GPU models, even if they are using pre-sharded
# weights. When using pre-sharded weights, the list of
# initialization-time transforms to apply is empty.
sharding_module = create_shard_info_func(param_manager, args, model_config)
mod.update(sharding_module)

with open(cache_path, "wb") as outfile:
pickle.dump(mod, outfile)
print(f"Save a cached module to {cache_path}.")