From bca068f619a9250b4aa32513c15e40970615a6dd Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Tue, 4 Jun 2024 05:23:13 -0700 Subject: [PATCH 1/3] ggml: avoid rebuild of GGML graph for each token (#7456) Introduces caching of GGML graph to avoid unnecessary full rebuild between each token. KV cache parameters, which change with each token, are updated directly in cached GGML graph. Can be disabled with GGML_DISABLE_GRAPH_CACHING environment variable. --- ggml/include/ggml-backend.h | 5 ++ ggml/src/ggml-backend.c | 33 +++++++++- src/llama.cpp | 122 +++++++++++++++++++++++++++++++++--- 3 files changed, 152 insertions(+), 8 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 4a38eeb5c23bd..1d406dc9d0ee6 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -230,6 +230,11 @@ extern "C" { GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor); + // Utility to query whether cached GGML graph is in use + GGML_API bool ggml_use_cached_graph(ggml_backend_sched_t sched); + + // Set whether or not to use GGML graph caching + GGML_API void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value); #ifdef __cplusplus } diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index 13c71c310c446..a7a61ac34ddcc 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -1036,6 +1036,13 @@ struct ggml_backend_sched_split { struct ggml_cgraph graph; }; +// Object to facilitate GML graph caching +struct ggml_cached_graph { + bool is_active; + ggml_backend_t input_backend; + struct ggml_tensor * input_cpy[GGML_SCHED_MAX_SPLIT_INPUTS]; +}; + struct ggml_backend_sched { bool is_reset; // true if the scheduler has been reset since the last graph split bool is_alloc; @@ -1087,6 +1094,8 @@ struct ggml_backend_sched { __attribute__((aligned(GGML_MEM_ALIGN))) #endif char context_buffer[GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + sizeof(struct ggml_cgraph)]; + + struct ggml_cached_graph cached_graph; }; #define hash_id(tensor) ggml_hash_find_or_insert(sched->hash_set, tensor) @@ -1753,6 +1762,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s struct ggml_tensor * input = split->inputs[j]; struct ggml_tensor * input_cpy = sched->tensor_copies[hash_id(input)][split_backend_id][sched->cur_copy]; + if (!sched->cached_graph.is_active) { + sched->cached_graph.input_backend = input_backend; + sched->cached_graph.input_cpy[j] = input_cpy; + } + else { + input_backend = sched->cached_graph.input_backend; + input_cpy = sched->cached_graph.input_cpy[j]; + } if (input->flags & GGML_TENSOR_FLAG_INPUT) { // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { @@ -1872,6 +1889,8 @@ ggml_backend_sched_t ggml_backend_sched_new( ggml_backend_sched_reset(sched); + sched->cached_graph.is_active = false; + return sched; } @@ -1947,6 +1966,9 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st } enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { + + if(!sched->cached_graph.is_active) + { if (!sched->is_reset && !sched->is_alloc) { ggml_backend_sched_reset(sched); } @@ -1956,7 +1978,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch return GGML_STATUS_ALLOC_FAILED; } } - + } return ggml_backend_sched_compute_splits(sched); } @@ -2223,3 +2245,12 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t return true; } + +bool ggml_use_cached_graph(ggml_backend_sched_t sched) { + return sched->cached_graph.is_active; +} + +void ggml_set_cached_graph(ggml_backend_sched_t sched, bool set_value) { + sched->cached_graph.is_active = set_value; +} + diff --git a/src/llama.cpp b/src/llama.cpp index 2b9ace2858457..1e601ac693e5f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2712,6 +2712,16 @@ struct llama_model { } }; +// Object used to allow caching of GGML graph between tokens where possible. +struct ggml_cached_graph { + ggml_cgraph * gf; + size_t n; + ggml_backend_t backend_res; + ggml_backend_t backend_embd; + struct ggml_tensor * res; + struct ggml_tensor * embd; +}; + struct llama_context { llama_context(const llama_model & model) : model(model), t_start_us(model.t_start_us), t_load_us(model.t_load_us) {} ~llama_context() { @@ -2813,6 +2823,8 @@ struct llama_context { // control vectors struct llama_control_vector cvec; + + struct ggml_cached_graph cached_graph; }; static size_t llama_get_device_count(const llama_model & model) { @@ -14524,12 +14536,37 @@ static int llama_decode_internal( ggml_backend_sched_reset(lctx.sched); ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data); - ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false); - + ggml_cgraph * gf; // the output is always the last tensor in the graph - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2]; + struct ggml_tensor * res; + struct ggml_tensor * embd; + + bool n_has_changed_since_last_token = false; + if(lctx.cached_graph.n != kv_self.n) n_has_changed_since_last_token = true; + lctx.cached_graph.n = kv_self.n; + + // Re-build graph only if graph caching is not possible + if(!ggml_use_cached_graph(lctx.sched) || n_has_changed_since_last_token) { + + gf = llama_build_graph(lctx, u_batch, false); + + // disable future graph caching in presense of env var, + // if there are multiple devices, or if batch size is greater than 1 + // TO DO enable graph caching for these cases + bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr) + || (llama_get_device_count(model) > 1); + for (int i = 0 ; i < gf->n_nodes; i++) { + if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) { + disable_cached_ggml_graph = true; + break; + } + } + + if(!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched,true); + // the output is always the last tensor in the graph + res = gf->nodes[gf->n_nodes - 1]; + embd = gf->nodes[gf->n_nodes - 2]; if (lctx.n_outputs == 0) { // no output res = nullptr; @@ -14545,10 +14582,71 @@ static int llama_decode_internal( embd = nullptr; // do not extract embeddings when not needed GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); } + lctx.cached_graph.res = res; + lctx.cached_graph.embd = embd; // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); ggml_backend_sched_alloc_graph(lctx.sched, gf); + } + else { + gf = lctx.cached_graph.gf; + res = lctx.cached_graph.res; + embd = lctx.cached_graph.embd; + } + lctx.cached_graph.gf = gf; + + if(ggml_use_cached_graph(lctx.sched)) { + + // If using flash attention, find mask node so it can be skipped when updating + // KV cache paramaters in cached graph nodes below + void * flash_attn_mask_node = nullptr; + if(cparams.flash_attn) { + for (int i = 0; i < gf->n_nodes; i++) { + ggml_tensor * node = gf->nodes[i]; + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + flash_attn_mask_node = node->src[3]; + break; + } + } + } + + // Temporarily store KV cache parameters that will need updated in cached graph. + const struct llama_hparams & hparams = model.hparams; + const int64_t n_layer = hparams.n_layer; + const int64_t kv_head = kv_self.head; + std::vector kv_cache_ptrs; + for (int il = 0; il < n_layer; ++il) { + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + ggml_tensor * tmp_tensor = kv_self.k_l[il]; + size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head; + kv_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); + tmp_tensor = kv_self.v_l[il]; + if (cparams.flash_attn) { + tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa); + } else { + tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]); + } + kv_cache_ptrs.push_back(static_cast(tmp_tensor->data) + tmp_offset); + } + + // Update KV cache parameters in cached graph. + int copy_op_count = 0; + if(gf != nullptr && gf->nodes != nullptr){ + for (int i = 0; i < gf->n_nodes; i++) { + ggml_tensor * node = gf->nodes[i]; + if (node->op == GGML_OP_CPY) { + if (node != flash_attn_mask_node) { + node->src[1]->data = kv_cache_ptrs[copy_op_count]; + copy_op_count++; + } + } + } + } + + } + llama_set_inputs(lctx, u_batch); llama_graph_compute(lctx, gf, n_threads); @@ -14571,11 +14669,15 @@ static int llama_decode_internal( // extract logits if (res) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(lctx.logits != nullptr); - float * logits_out = lctx.logits + n_outputs_prev*n_vocab; const int32_t n_outputs_new = lctx.n_outputs; + if(!ggml_use_cached_graph(lctx.sched)) + lctx.cached_graph.backend_res = backend_res; + else + backend_res = lctx.cached_graph.backend_res; + + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(lctx.logits != nullptr); if (n_outputs_new) { GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs); @@ -14587,6 +14689,12 @@ static int llama_decode_internal( // extract embeddings if (embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd); + + + if(!ggml_use_cached_graph(lctx.sched)) + lctx.cached_graph.backend_embd = backend_embd; + else + backend_embd = lctx.cached_graph.backend_embd; GGML_ASSERT(backend_embd != nullptr); switch (cparams.pooling_type) { From b7956a8532d117ff8918b44b64db7cf93f98b560 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Mon, 8 Jul 2024 08:43:21 -0700 Subject: [PATCH 2/3] fix seg fault --- src/llama.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 1e601ac693e5f..6bd0863c63e8c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2714,6 +2714,7 @@ struct llama_model { // Object used to allow caching of GGML graph between tokens where possible. struct ggml_cached_graph { + bool is_active = false; ggml_cgraph * gf; size_t n; ggml_backend_t backend_res; @@ -14550,7 +14551,11 @@ static int llama_decode_internal( gf = llama_build_graph(lctx, u_batch, false); - // disable future graph caching in presense of env var, + // Set whether GGML graph caching is in use within GGML module, based on + // whether caching was activated here during the previous token + ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active); + + // Disable future graph caching in presence of env var, // if there are multiple devices, or if batch size is greater than 1 // TO DO enable graph caching for these cases bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr) @@ -14562,7 +14567,8 @@ static int llama_decode_internal( } } - if(!disable_cached_ggml_graph) ggml_set_cached_graph(lctx.sched,true); + // Set whether graph caching should be used for future tokens + lctx.cached_graph.is_active=!disable_cached_ggml_graph; // the output is always the last tensor in the graph res = gf->nodes[gf->n_nodes - 1]; From a34900aad194ae0239b3fad94c48dc82b9f3a1a1 Mon Sep 17 00:00:00 2001 From: Alan Gray Date: Wed, 10 Jul 2024 03:29:12 -0700 Subject: [PATCH 3/3] restrict to nsplit=2 --- src/llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 6bd0863c63e8c..4a309b999205a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14556,10 +14556,12 @@ static int llama_decode_internal( ggml_set_cached_graph(lctx.sched,lctx.cached_graph.is_active); // Disable future graph caching in presence of env var, - // if there are multiple devices, or if batch size is greater than 1 + // if there are multiple devices, if batch size is greater than 1, + // or if nsplits is not 2. // TO DO enable graph caching for these cases bool disable_cached_ggml_graph = (getenv("GGML_DISABLE_GRAPH_CACHING") != nullptr) - || (llama_get_device_count(model) > 1); + || (llama_get_device_count(model) > 1) + || (ggml_backend_sched_get_n_splits(lctx.sched) != 2); for (int i = 0 ; i < gf->n_nodes; i++) { if (gf->nodes[i]->op == GGML_OP_ADD && gf->nodes[i]->src[1] && gf->nodes[i]->src[1]->ne[1] > 1) { disable_cached_ggml_graph = true;