
llama : accept a list of devices to use to offload a model
slaren committed Nov 25, 2024
1 parent 9ca2e67 commit f4457cb
Showing 7 changed files with 99 additions and 24 deletions.
57 changes: 53 additions & 4 deletions common/arg.cpp
@@ -1312,6 +1312,40 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
else { throw std::invalid_argument("invalid value"); }
}
).set_env("LLAMA_ARG_NUMA"));
add_opt(common_arg(
{"-dev", "--device"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading\n"
"use --list-devices to see a list of available devices",
[](common_params & params, const std::string & value) {
auto devices = string_split<std::string>(value, ',');
if (devices.empty()) {
throw std::invalid_argument("no devices specified");
}
for (const auto & device : devices) {
auto * dev = ggml_backend_dev_by_name(device.c_str());
if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
}
params.devices.push_back(dev);
}
params.devices.push_back(nullptr);
}
).set_env("LLAMA_ARG_DEVICES"));
add_opt(common_arg(
{"--list-devices"},
"print list of available devices and exit",
[](common_params &) {
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf("%s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
}
exit(0);
}
));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
@@ -1336,10 +1370,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
} else if (arg_next == "row") {
#ifdef GGML_USE_SYCL
fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n");
exit(1);
#endif // GGML_USE_SYCL
params.split_mode = LLAMA_SPLIT_MODE_ROW;
} else {
throw std::invalid_argument("invalid value");
@@ -2042,6 +2072,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.n_ctx = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model\n"
"use --list-devices to see a list of available devices",
[](common_params & params, const std::string & value) {
auto devices = string_split<std::string>(value, ',');
if (devices.empty()) {
throw std::invalid_argument("no devices specified");
}
for (const auto & device : devices) {
auto * dev = ggml_backend_dev_by_name(device.c_str());
if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
}
params.speculative.devices.push_back(dev);
}
params.speculative.devices.push_back(nullptr);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",
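
A usage sketch of the two flags added above (an illustration, not part of the commit; device names such as CUDA0 and CUDA1 depend on which backends the build exposes and can be inspected with --list-devices):

    # print the available GPU devices and exit
    llama-cli --list-devices

    # restrict offloading of the model to two explicitly named devices
    llama-cli -m model.gguf -ngl 99 -dev CUDA0,CUDA1
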
5 changes: 4 additions & 1 deletion common/common.cpp
@@ -982,9 +982,12 @@ void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_l
}
}

struct llama_model_params common_model_params_to_llama(const common_params & params) {
struct llama_model_params common_model_params_to_llama(common_params & params) {
auto mparams = llama_model_default_params();

if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}
14 changes: 9 additions & 5 deletions common/common.h
@@ -156,6 +156,7 @@ struct common_params_sampling {
};

struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
@@ -178,9 +179,6 @@ struct common_params {
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
@@ -193,6 +191,13 @@ struct common_params {
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = 0.1f; // KV cache defragmentation threshold

// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs

struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;

@@ -201,7 +206,6 @@ struct common_params {

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
@@ -462,7 +466,7 @@ struct common_init_result {

struct common_init_result common_init_from_params(common_params & params);

struct llama_model_params common_model_params_to_llama (const common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);

1 change: 1 addition & 0 deletions examples/speculative/speculative.cpp
@@ -76,6 +76,7 @@ int main(int argc, char ** argv) {
ctx_tgt = llama_init_tgt.context;

// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.n_gpu_layers = params.speculative.n_gpu_layers;
if (params.speculative.cpuparams.n_threads > 0) {
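
With this one-line change the draft model is initialized from params.speculative.devices rather than the target model's device list, so the two models can be pinned to different GPUs. A hypothetical invocation (device names are illustrative; -m/-md/-ngl/-ngld are the pre-existing model and offload flags):

    llama-speculative -m target.gguf -md draft.gguf -ngl 99 -ngld 99 -dev CUDA0 -devd CUDA1
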
13 changes: 11 additions & 2 deletions ggml/src/ggml-backend-reg.cpp
@@ -253,6 +253,15 @@ void ggml_backend_device_register(ggml_backend_dev_t device) {
}

// Backend (reg) enumeration
static bool striequals(const char * a, const char * b) {
for (; *a && *b; a++, b++) {
if (std::tolower(*a) != std::tolower(*b)) {
return false;
}
}
return *a == *b;
}

size_t ggml_backend_reg_count() {
return get_reg().backends.size();
}
@@ -265,7 +274,7 @@ ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
ggml_backend_reg_t reg = ggml_backend_reg_get(i);
if (std::strcmp(ggml_backend_reg_name(reg), name) == 0) {
if (striequals(ggml_backend_reg_name(reg), name)) {
return reg;
}
}
@@ -285,7 +294,7 @@ ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (strcmp(ggml_backend_dev_name(dev), name) == 0) {
if (striequals(ggml_backend_dev_name(dev), name)) {
return dev;
}
}
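
Switching these lookups from strcmp to striequals makes registry and device names case-insensitive, so a name typed on the command line no longer has to match the registered casing. A minimal sketch of the resulting behavior (assuming a build that registers a device named CUDA0):

    #include "ggml-backend.h"

    void lookup_example(void) {
        ggml_backend_dev_t a = ggml_backend_dev_by_name("CUDA0");
        ggml_backend_dev_t b = ggml_backend_dev_by_name("cuda0");
        // a and b now refer to the same device; both are NULL if no such device is registered
    }
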
3 changes: 3 additions & 0 deletions include/llama.h
@@ -272,6 +272,9 @@ extern "C" {
};

struct llama_model_params {
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
ggml_backend_dev_t * devices;

int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs

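
A minimal sketch of how a caller might use the new field through the C API (illustrative only; "CUDA0" stands in for whatever name --list-devices reports, and error handling is reduced to the essentials):

    #include "llama.h"
    #include "ggml-backend.h"

    // Load a model while offloading only to one explicitly chosen device.
    // Leaving mparams.devices as NULL keeps the previous behavior of using
    // every available GPU device.
    static struct llama_model * load_on_device(const char * path, const char * dev_name) {
        ggml_backend_dev_t devices[2] = { ggml_backend_dev_by_name(dev_name), NULL };

        struct llama_model_params mparams = llama_model_default_params();
        if (devices[0] != NULL) {
            mparams.devices = devices;   // NULL-terminated list, as documented above
        }
        mparams.n_gpu_layers = 99;       // offload all layers to the selected device

        return llama_load_model_from_file(path, mparams);
    }
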
30 changes: 18 additions & 12 deletions src/llama.cpp
@@ -19364,6 +19364,7 @@ void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
//
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.devices =*/ nullptr,
/*.n_gpu_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
@@ -19576,19 +19577,24 @@ struct llama_model * llama_load_model_from_file(
}

// create list of devices to use with this model
// currently, we use all available devices
// TODO: rework API to give user more control over device selection
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
switch (ggml_backend_dev_type(dev)) {
case GGML_BACKEND_DEVICE_TYPE_CPU:
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
// skip CPU backends since they are handled separately
break;
if (params.devices) {
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
model->devices.push_back(*dev);
}
} else {
// use all available devices
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
switch (ggml_backend_dev_type(dev)) {
case GGML_BACKEND_DEVICE_TYPE_CPU:
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
// skip CPU backends since they are handled separately
break;

case GGML_BACKEND_DEVICE_TYPE_GPU:
model->devices.push_back(dev);
break;
case GGML_BACKEND_DEVICE_TYPE_GPU:
model->devices.push_back(dev);
break;
}
}
}
