From f009e40c0dca1d07c648c157cf9ed2a665e969af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 20 Apr 2024 08:24:21 +0200 Subject: [PATCH] Server: enable lookup decoding --- Makefile | 2 +- common/common.cpp | 4 +- common/ngram-cache.cpp | 27 ++- common/ngram-cache.h | 16 +- examples/lookup/README.md | 80 +++++++- examples/lookup/lookup-create.cpp | 2 +- examples/lookup/lookup-merge.cpp | 6 +- examples/lookup/lookup-stats.cpp | 19 +- examples/lookup/lookup.cpp | 18 +- examples/server/README.md | 6 +- examples/server/bench/bench.py | 8 + examples/server/server.cpp | 194 ++++++++++++++---- .../server/tests/features/results.feature | 30 ++- examples/server/tests/features/steps/steps.py | 14 +- 14 files changed, 331 insertions(+), 95 deletions(-) diff --git a/Makefile b/Makefile index 0a73f2a582a204..59afb85ceee7fb 100644 --- a/Makefile +++ b/Makefile @@ -825,7 +825,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/server/json-schema-to-grammar.mjs.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS) +server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/server/json-schema-to-grammar.mjs.hpp common/stb_image.h ggml.o llama.o ngram-cache.o $(COMMON_DEPS) grammar-parser.o $(OBJS) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) diff --git a/common/common.cpp b/common/common.cpp index 243b88abf1aab4..ffd1161937d8a8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1584,9 +1584,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -ld LOGDIR, --logdir LOGDIR\n"); printf(" path under which to save YAML logs (no logging if unset)\n"); printf(" -lcs FNAME, --lookup-cache-static FNAME\n"); - printf(" path to static lookup cache to use for lookup decoding (not updated by generation)\n"); + printf(" path to static lookup cache to use for n-gram lookup decoding (not updated by generation)\n"); printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n"); - printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n"); + printf(" path to dynamic lookup cache to use for n-gram lookup decoding (updated by generation)\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" advanced option to override model metadata by key. may be specified multiple times.\n"); printf(" types: int, float, bool, str. 
example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); diff --git a/common/ngram-cache.cpp b/common/ngram-cache.cpp index 3ca112ef1613d8..3247643320e893 100644 --- a/common/ngram-cache.cpp +++ b/common/ngram-cache.cpp @@ -6,19 +6,18 @@ #include void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, - std::vector & inp, int nnew, bool print_progress) { + llama_token * inp_data, int inp_size, int nnew, bool print_progress) { const int64_t t_start_ms = ggml_time_ms(); - const int64_t inp_size = inp.size(); const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1); int64_t n_done = 0; for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) { - const int64_t i_start = std::max(inp_size - nnew, ngram_size); + const int64_t i_start = std::max((int64_t)(inp_size - nnew), ngram_size); for (int64_t i = i_start; i < inp_size; ++i) { const int64_t ngram_start = i - ngram_size; - llama_ngram ngram(&inp[ngram_start], ngram_size); - const llama_token token = inp[i]; + llama_ngram ngram(inp_data + ngram_start, ngram_size); + const llama_token token = inp_data[i]; llama_ngram_cache::iterator part_it = ngram_cache.find(ngram); if (part_it == ngram_cache.end()) { @@ -48,8 +47,8 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in } // Helper function to get a token from the combined, speculative sequence of inp and draft. -static llama_token get_token(const std::vector & inp, const std::vector & draft, const size_t i) { - return i < inp.size() ? inp[i] : draft[1 + i - inp.size()]; +static llama_token get_token(const llama_token * inp_data, const int inp_size, const std::vector & draft, const int i) { + return i < inp_size ? inp_data[i] : draft[1 + i - inp_size]; } // If sample size or percentage are below these thresholds the draft is aborted early: @@ -140,11 +139,10 @@ static llama_token try_draft( } void llama_ngram_cache_draft( - std::vector & inp, std::vector & draft, int n_draft, int ngram_min, int ngram_max, + llama_token * inp_data, int inp_size, std::vector & draft, int n_draft, int ngram_min, int ngram_max, llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static ) { GGML_ASSERT(draft.size() == 1); - const int inp_size = inp.size(); if (inp_size < LLAMA_NGRAM_STATIC) { return; @@ -156,7 +154,7 @@ void llama_ngram_cache_draft( const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1; llama_ngram ngram_static; for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) { - ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j); + ngram_static.tokens[j-ngram_start_static] = get_token(inp_data, inp_size, draft, j); } llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static); llama_ngram_cache_part part_static; @@ -170,7 +168,7 @@ void llama_ngram_cache_draft( const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1; llama_ngram ngram_cd; for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) { - ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j); + ngram_cd.tokens[j-ngram_start_cd] = get_token(inp_data, inp_size, draft, j); } ngrams_cd.push_back(ngram_cd); } @@ -216,12 +214,11 @@ void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filen } -llama_ngram_cache llama_ngram_cache_load(std::string & filename) { +bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename) { std::ifstream hashmap_file(filename, 
std::ios::binary); if (!hashmap_file) { - throw std::ifstream::failure("Unable to open file " + filename); + return false; } - llama_ngram_cache ngram_cache; llama_ngram ngram; int32_t ntokens; @@ -251,7 +248,7 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) { } GGML_ASSERT(hashmap_file.eof()); - return ngram_cache; + return true; } void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) { diff --git a/common/ngram-cache.h b/common/ngram-cache.h index e4fa4cbd12f11e..6575ea05fa6b9a 100644 --- a/common/ngram-cache.h +++ b/common/ngram-cache.h @@ -39,10 +39,13 @@ struct llama_ngram { struct llama_ngram_hash_function { size_t operator()(const llama_ngram & ngram) const { - size_t hash = 0; - for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) { - hash ^= std::hash{}(ngram.tokens[i]); + size_t hash = ngram.tokens[0]; + + for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) { + hash <<= 15; + hash ^= ngram.tokens[i]; } + return hash; } }; @@ -64,7 +67,7 @@ typedef std::unordered_map & inp_data, int nnew, bool print_progress); + llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, llama_token * inp_data, int inp_size, int nnew, bool print_progress); // Try to draft tokens from ngram caches. // inp: the tokens generated so far. @@ -75,7 +78,7 @@ void llama_ngram_cache_update( // nc_dynamic: ngram cache based on previous user generations. // nc_static: ngram cache generated from a large text corpus, used for validation. void llama_ngram_cache_draft( - std::vector & inp, std::vector & draft, int n_draft, int ngram_min, int ngram_max, + llama_token * inp_data, int inp_size, std::vector & draft, int n_draft, int ngram_min, int ngram_max, llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static); // Save an ngram cache to a file. @@ -84,9 +87,10 @@ void llama_ngram_cache_draft( void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename); // Load an ngram cache saved with llama_ngram_cache_save. +// ngram_cache: the ngram cache to load the data into. // filename: the path from which to load the ngram cache. -// returns: an ngram cache containing the information saved to filename. +// returns: true if the cache could be loaded from filename, false otherwise. -llama_ngram_cache llama_ngram_cache_load(std::string & filename); +bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename); // Merge two ngram caches. // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add. diff --git a/examples/lookup/README.md b/examples/lookup/README.md index 5bfb0de9360414..876d8a94be17cb 100644 --- a/examples/lookup/README.md +++ b/examples/lookup/README.md @@ -1,13 +1,81 @@ # llama.cpp/examples/lookup -Demonstration of Prompt Lookup Decoding +Demonstration of speculative decoding using n-gram lookup. +The initial version was based on https://github.com/apoorvumang/prompt-lookup-decoding . +The current version uses three separate types of "n-gram caches". +Each of these caches maps how frequently a given n-gram is followed by a specific token. +The difference between the caches lies in what data is used to build them: -https://github.com/apoorvumang/prompt-lookup-decoding +* The "context" cache is built using the tokens in the current context of a user generation. +* The "dynamic" cache is built by merging the context caches of previous user generations. +* The "static" cache is built from a large text corpus with no relation to the current context.
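To make the cache layout concrete, here is a minimal, self-contained sketch. The names are illustrative only; the actual types in `common/ngram-cache.h` use a fixed-size token array with a custom hash function and `std::unordered_map`, and the real drafting code additionally applies minimum-count and sample-size thresholds and a fixed n-gram size for the static cache. Conceptually, each cache maps an n-gram to the counts of the tokens observed to follow it, and the "validation" used by the drafting algorithm described below simply multiplies the relative frequencies found in two caches:

```cpp
// Illustrative sketch only -- not the actual llama.cpp types.
#include <cstdint>
#include <map>
#include <vector>

using token_t     = int32_t;
using ngram_t     = std::vector<token_t>;           // ngram_min..ngram_max consecutive tokens
using cache_part  = std::map<token_t, int32_t>;     // follower token -> observed count
using ngram_cache = std::map<ngram_t, cache_part>;  // n-gram -> follower counts

// Count, for every n-gram of size ngram_size in inp, which token followed it:
static void cache_update(ngram_cache & nc, const std::vector<token_t> & inp, size_t ngram_size) {
    for (size_t i = ngram_size; i < inp.size(); ++i) {
        const ngram_t key(inp.begin() + (i - ngram_size), inp.begin() + i);
        nc[key][inp[i]] += 1;
    }
}

// Combine a candidate's relative frequency in two caches by multiplication
// (this is how a static cache "validates" candidates drawn from another cache):
static double combined_frequency(const ngram_cache & nc_primary, const ngram_cache & nc_static,
                                 const ngram_t & key, token_t candidate) {
    double freq = 1.0;
    for (const ngram_cache * nc : { &nc_primary, &nc_static }) {
        const auto it = nc->find(key);
        if (it == nc->end()) {
            return 0.0; // this cache has never seen the n-gram -> nothing to draft
        }
        int32_t total = 0;
        int32_t count = 0;
        for (const auto & kv : it->second) {
            total += kv.second;
            if (kv.first == candidate) {
                count = kv.second;
            }
        }
        freq *= total > 0 ? (double) count / total : 0.0;
    }
    return freq;
}
```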
-The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the size of the ngrams to search for in the prompt for a match. The latter specifies how many subsequent tokens to draft if a match is found. +The tradeoff between these caches lies in relevance to the current context vs. the amount of input data. +When trying to draft a new token using n-gram lookup, the algorithm is as follows: -More info: +* Try to draft a suitable token from the context cache. If a static cache is available, use it to validate the draft candidates. This is done by simply multiplying the frequencies of the two caches. +* Try to draft a suitable token from the dynamic cache, validate with static cache if available. +* Try to draft a suitable token from the static cache. -https://github.com/ggerganov/llama.cpp/pull/4484 -https://github.com/ggerganov/llama.cpp/issues/4226 +Only a single token sequence with the most likely token candidates is drafted. +All tokens must pass thresholds for frequency and sample size in order to be drafted. +Relevant command line arguments: + +- `--draft`: maximum number of additional tokens to draft using n-gram lookup. Default: 5. Set to 0 to disable n-gram lookup. **Results are not deterministic with n-gram lookup enabled due to varying batch size.** +- `-lcs FNAME, --lookup-cache-static FNAME`: optional path to static lookup cache to use for n-gram lookup. Created from a large, unspecific text corpus using `lookup-create`. +- `-lcd FNAME, --lookup-cache-dynamic FNAME`: optional path to dynamic lookup cache to use for n-gram lookup. Contains data from previous generations. Automatically created and filled while the server is running but by default discarded on server exit. Setting this argument tries to initialize the dynamic cache from a file and saves it to said file on server shutdown. + +N-gram lookup caches saved to disk are compatible between models as long as they use the same tokenizer +(but for dynamic caches the resulting drafted tokens may be wrong which means there is no speedup). +Furthermore, the data format for both types of caches is the same so they can be used interchangeably (but probably not with good results). + +## Usage Examples + +### `lookup` + +Generation using n-gram lookup: + +``` sh +./lookup --model models/opt/llama_2-7b-q4_0.gguf -ngl 99 --n-predict 256 --ignore-eos --draft 3 --color --prompt "Write a love story about two stars that tragically ends in a type Ia supernova. Use a lot of emotional and dramatic language." +``` + +The `--color` flag highlights the successfully predicted tokens. +The `--lookup-cache-static` and `--lookup-cache-dynamic` arguments can be set to provide static/dynamic caches. + +### `lookup-stats` + +Determine n-gram lookup effectiveness for a given text corpus (similar to `perplexity`): + +``` sh +./lookup-stats --model /opt/models/llama_2-7b-q4_0.gguf --file wikitext-2-raw/wiki.test.raw --draft 3 +``` + +The `--lookup-cache-static` and `--lookup-cache-dynamic` arguments can be set to provide static/dynamic caches. + +### `lookup-create` + +Create a static lookup cache from a text corpus: + +``` sh +./lookup-create --model /opt/models/llama_2-7b-q4_0.gguf --lookup-cache-static wt103-llama_2.lcs --file wikitext-103-raw/wiki.train.raw +``` + +The `--lookup-cache-static` argument must be set to provide the path to which the static lookup cache will be saved. +The tokenizer for which to create the cache is taken from the provided model. 
+ +### `lookup-merge` + +Merge two lookup caches into one: + +``` sh +./lookup-merge cache_1.lcs cache_2.lcs cache_merged.lcs +``` + +Can be used for both static and dynamic lookup caches. + +## More info + +* https://github.com/ggerganov/llama.cpp/pull/4484 +* https://github.com/ggerganov/llama.cpp/issues/4226 +* https://github.com/ggerganov/llama.cpp/pull/6828 diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index 1c230c9667c715..4d536da6d55105 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -34,7 +34,7 @@ int main(int argc, char ** argv){ llama_ngram_cache ngram_cache; - llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true); + llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp.data(), inp.size(), inp.size(), true); fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str()); llama_ngram_cache_save(ngram_cache, params.lookup_cache_static); diff --git a/examples/lookup/lookup-merge.cpp b/examples/lookup/lookup-merge.cpp index 07c93eb8d057bb..17e33ee09e1f7b 100644 --- a/examples/lookup/lookup-merge.cpp +++ b/examples/lookup/lookup-merge.cpp @@ -33,11 +33,13 @@ int main(int argc, char ** argv){ } fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str()); - llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]); + llama_ngram_cache ngram_cache_merged; + GGML_ASSERT(llama_ngram_cache_load(ngram_cache_merged, args[0])); for (size_t i = 1; i < args.size()-1; ++i) { fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str()); - llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]); + llama_ngram_cache ngram_cache; + GGML_ASSERT(llama_ngram_cache_load(ngram_cache, args[i])); llama_ngram_cache_merge(ngram_cache_merged, ngram_cache); } diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 87ecc0a4f1394e..0cea7eec8a513f 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -46,18 +46,15 @@ int main(int argc, char ** argv){ const int64_t t_start_draft_us = ggml_time_us(); if (!params.lookup_cache_static.empty()) { - try { - ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static); - } catch (std::ifstream::failure const &) { + if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) { fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); exit(1); } } if (!params.lookup_cache_dynamic.empty()) { - try { - ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic); - } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program + // If the dynamic lookup cache doesn't exist it will be created at the end of the program: + llama_ngram_cache_load(ngram_cache_dynamic, params.lookup_cache_dynamic); } t_draft_flat_us += ggml_time_us() - t_start_draft_us; @@ -85,7 +82,9 @@ int main(int argc, char ** argv){ { const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); + llama_ngram_cache_draft( + pseudo_output.data(), pseudo_output.size(), draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); t_draft_us += ggml_time_us() - t_start_draft_us; } @@ -104,7 
+103,8 @@ int main(int argc, char ** argv){ { const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); + llama_ngram_cache_update( + ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output.data(), pseudo_output.size(), 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } } @@ -114,7 +114,8 @@ int main(int argc, char ** argv){ pseudo_output.push_back(inp_slice[pseudo_output.size()]); { const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); + llama_ngram_cache_update( + ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output.data(), pseudo_output.size(), 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index eebbd00a58e66c..dc2878de4359ed 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -53,21 +53,18 @@ int main(int argc, char ** argv){ { // Fill up context ngram cache with tokens from user input: const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false); + llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp.data(), inp.size(), inp.size(), false); if (!params.lookup_cache_static.empty()) { - try { - ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static); - } catch (std::ifstream::failure const &) { + if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) { fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); exit(1); } } if (!params.lookup_cache_dynamic.empty()) { - try { - ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic); - } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program + // If the dynamic lookup cache doesn't exist it will be created at the end of the program: + llama_ngram_cache_load(ngram_cache_dynamic, params.lookup_cache_dynamic); } t_draft_flat_us += ggml_time_us() - t_start_draft_us; @@ -156,7 +153,7 @@ int main(int argc, char ** argv){ { // Update context ngram cache with the newly accepted token: const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); + llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp.data(), inp.size(), 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } @@ -182,7 +179,7 @@ int main(int argc, char ** argv){ { // Update context ngram cache with the newly accepted token: const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); + llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp.data(), inp.size(), 1, false); t_draft_us += ggml_time_us() - t_start_draft_us; } break; @@ -204,7 +201,8 @@ int main(int argc, char ** argv){ GGML_ASSERT(draft[0] == inp.back()); const int64_t t_start_draft_us = ggml_time_us(); - llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); + llama_ngram_cache_draft( + inp.data(), inp.size(), draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, 
ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); for (size_t i = 1; i < draft.size(); ++i) { llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); diff --git a/examples/server/README.md b/examples/server/README.md index b96a4444a2bd3b..c6b500d5f15cbc 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -8,6 +8,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. * LLM inference of F16 and quantum models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes * Parallel decoding with multi-user support + * Speculative decoding based on n-gram lookup * Continuous batching * Multimodal (wip) * Monitoring endpoints @@ -49,7 +50,7 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - `--api-key`: Set an api key for request authorization. By default, the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token. May be used multiple times to enable multiple valid keys. - `--api-key-file`: Path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access. May be used in conjunction with `--api-key`s. - `--embedding`: Enable embedding extraction. Default: disabled -- `-np N`, `--parallel N`: Set the number of slots for process requests. Default: `1` +- `-np N`, `--parallel N`: Set the number of slots for process requests. Default: `1`. **Values > 1 produce nondeterministic results depending on the number of active slots.** - `-cb`, `--cont-batching`: Enable continuous batching (a.k.a dynamic batching). Default: disabled - `-spf FNAME`, `--system-prompt-file FNAME` Set a file to load a system prompt (initial prompt of all slots). This is useful for chat applications. [See more](#change-system-prompt-on-runtime) - `--mmproj MMPROJ_FILE`: Path to a multimodal projector file for LLaVA. @@ -62,6 +63,9 @@ page cache before using this. See https://github.com/ggerganov/llama.cpp/issues/ - `--chat-template JINJA_TEMPLATE`: Set custom jinja chat template. This parameter accepts a string, not a file name. Default: template taken from model's metadata. We only support [some pre-defined templates](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) - `--log-disable`: Output logs to stdout only, not to `llama.log`. Default: enabled - `--log-format FORMAT`: Define the log output to FORMAT: json or text Default: `json` +- `--draft`: maximum number of additional tokens to draft using n-gram lookup. Default: 5. Set to 0 to disable n-gram lookup. **Results are not deterministic with n-gram lookup enabled due to varying batch size.** +- `-lcs FNAME, --lookup-cache-static FNAME`: optional path to static lookup cache to use for n-gram lookup. Created from a large, unspecific text corpus using `lookup-create`. +- `-lcd FNAME, --lookup-cache-dynamic FNAME`: optional path to dynamic lookup cache to use for n-gram lookup. Contains data from previous generations. Automatically created and filled while the server is running but by default discarded on server exit. Setting this argument tries to initialize the dynamic cache from a file and saves it to said file on server shutdown. 
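Internally, the server wires these options into each slot roughly as sketched below. This is a simplified illustration built on the `llama_ngram_cache` API added by this patch; the helper name is hypothetical, and error handling, the system-prompt offset and batch/KV-cache bookkeeping are omitted (see `examples/server/server.cpp` for the real logic).

```cpp
#include "common.h"
#include "ngram-cache.h"

#include <vector>

// Hypothetical helper: one lookup-decoding step for a single server slot.
static std::vector<llama_token> slot_lookup_draft(
        std::vector<llama_token> & context_tokens, // tokens accepted so far for this slot
        llama_ngram_cache        & nc_context,     // per-slot cache, reset together with the slot
        llama_ngram_cache        & nc_dynamic,     // shared cache, merged from slots on release
        llama_ngram_cache        & nc_static,      // optional cache loaded via -lcs
        int                        n_draft) {
    // 1. fold the newest accepted token into the per-slot context cache:
    llama_ngram_cache_update(nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
                             context_tokens.data(), context_tokens.size(), 1, false);

    // 2. seed the draft with the last accepted token and let the caches extend it:
    std::vector<llama_token> draft = { context_tokens.back() };
    llama_ngram_cache_draft(context_tokens.data(), context_tokens.size(), draft, n_draft,
                            LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, nc_context, nc_dynamic, nc_static);

    // draft[1..] is what gets added to the batch as speculative tokens; after decoding,
    // only the prefix that matches the actually sampled tokens is kept and the rest is
    // rolled back from the KV cache.
    return draft;
}
```

When a slot is released, its context cache is merged into the shared dynamic cache with `llama_ngram_cache_merge`; if `-lcd` was given, the dynamic cache is written back with `llama_ngram_cache_save` on server shutdown.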
**If compiled with `LLAMA_SERVER_SSL=ON`** - `--ssl-key-file FNAME`: path to file a PEM-encoded SSL private key diff --git a/examples/server/bench/bench.py b/examples/server/bench/bench.py index 86c5de101445c1..5ffa847b0c0c60 100644 --- a/examples/server/bench/bench.py +++ b/examples/server/bench/bench.py @@ -45,6 +45,9 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True) parser.add_argument("--scenario", type=str, help="Scenario to run", required=True) parser.add_argument("--duration", type=str, help="Bench scenario", required=True) + parser.add_argument("--draft", type=int, help="Max. number of additional tokens to draft for lookup decoding", required=False, default=5) + parser.add_argument("-lcs", "--lookup-cache-static", type=str, help="Path to optional static lookup cache to use.", required=False, default=None) + parser.add_argument("-lcd", "--lookup-cache-dynamic", type=str, help="Path to optional dynamic lookup cache to use. Will be overwritten upon server shutdown.", required=False, default=None) args = parser.parse_args(args_in) @@ -270,6 +273,11 @@ def start_server_background(args): server_args.append('--metrics') server_args.append('--flash-attn') server_args.extend(['--log-format', "text"]) + server_args.extend(['--draft', args.draft]) + if args.lookup_cache_static is not None: + server_args.extend(['--lookup-cache-static', args.lookup_cache_static]) + if args.lookup_cache_dynamic is not None: + server_args.extend(['--lookup-cache-dynamic', args.lookup_cache_dynamic]) args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") pkwargs = { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f60530cf3db561..cb3f6581273b78 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2,8 +2,9 @@ #include "common.h" #include "json-schema-to-grammar.h" -#include "llama.h" #include "grammar-parser.h" +#include "llama.h" +#include "ngram-cache.h" #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error @@ -163,6 +164,10 @@ struct server_slot { // when a task is submitted, we first tokenize the prompt and store it here std::vector prompt_tokens; + llama_ngram_cache nc_context; + std::vector draft; + std::vector context_tokens; + std::string generated_text; std::vector cache_tokens; std::vector generated_token_probs; @@ -218,6 +223,9 @@ struct server_slot { n_past_se = 0; generated_token_probs.clear(); + + nc_context.clear(); + draft.clear(); } bool has_budget(gpt_params &global_params) { @@ -258,7 +266,7 @@ struct server_slot { } } - json get_formated_timings() const { + json get_formatted_timings() const { return json { {"prompt_n", n_prompt_tokens_processed}, {"prompt_ms", t_prompt_processing}, @@ -423,7 +431,7 @@ struct server_queue { queue_tasks_deferred.push_back(std::move(task)); } - // Get the next id for creating anew task + // Get the next id for creating a new task int get_new_id() { std::unique_lock lock(mutex_tasks); int new_id = id++; @@ -539,7 +547,7 @@ struct server_queue { queue_multitasks.push_back(multi); } - // updatethe remaining subtasks, while appending results to multitask + // update the remaining subtasks, while appending results to multitask void update_multitask(int id_multi, int id_sub, server_task_result & result) { std::lock_guard lock(mutex_tasks); for (auto & multitask : queue_multitasks) { @@ -572,7 +580,7 @@ struct server_response { 
waiting_task_ids.insert(id_task); } - // when the request is finished, we can remove task associated with it + // when the request is finished, we can remove the task associated with it void remove_waiting_task_id(int id_task) { LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}}); @@ -656,6 +664,10 @@ struct server_context { std::vector slots; json default_generation_settings_for_props; + int32_t n_draft = 3; + llama_ngram_cache nc_dynamic; + llama_ngram_cache nc_static; + server_queue queue_tasks; server_response queue_results; @@ -714,6 +726,8 @@ struct server_context { slot.n_ctx = n_ctx_slot; slot.n_predict = params.n_predict; + slot.context_tokens.resize(n_ctx_slot); + LOG_INFO("new slot", { {"id_slot", slot.id}, {"n_ctx_slot", slot.n_ctx} @@ -744,7 +758,7 @@ struct server_context { slots.push_back(slot); } - default_generation_settings_for_props = get_formated_generation(slots.front()); + default_generation_settings_for_props = get_formatted_generation(slots.front()); default_generation_settings_for_props["seed"] = -1; // the update_slots() logic will always submit a maximum of n_batch tokens @@ -1065,6 +1079,13 @@ struct server_context { for (int i = 0; i < (int)system_tokens.size(); ++i) { llama_batch_add(batch, system_tokens[i], i, { 0 }, false); } + if (n_draft > 0) { + for (auto slot : slots) { + memcpy(slot.context_tokens.data(), system_tokens.data(), system_tokens.size()*sizeof(llama_token)); + llama_ngram_cache_update( + slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, system_tokens.data(), system_tokens.size(), system_tokens.size(), false); + } + } const int32_t n_batch = llama_n_batch(ctx); @@ -1245,7 +1266,7 @@ struct server_context { return slot.has_next_token; // continue } - json get_formated_generation(const server_slot & slot) const { + json get_formatted_generation(const server_slot & slot) const { const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); @@ -1367,7 +1388,7 @@ struct server_context { {"model", params.model_alias}, {"tokens_predicted", slot.n_decoded}, {"tokens_evaluated", slot.n_prompt_tokens}, - {"generation_settings", get_formated_generation(slot)}, + {"generation_settings", get_formatted_generation(slot)}, {"prompt", slot.prompt}, {"truncated", slot.truncated}, {"stopped_eos", slot.stopped_eos}, @@ -1375,7 +1396,7 @@ struct server_context { {"stopped_limit", slot.stopped_limit}, {"stopping_word", slot.stopping_word}, {"tokens_cached", slot.n_past}, - {"timings", slot.get_formated_timings()} + {"timings", slot.get_formatted_timings()} }; if (slot.sparams.n_probs > 0) { @@ -1573,7 +1594,7 @@ struct server_context { int n_processing_slots = 0; for (server_slot & slot : slots) { - json slot_data = get_formated_generation(slot); + json slot_data = get_formatted_generation(slot); slot_data["id"] = slot.id; slot_data["id_task"] = slot.id_task; slot_data["state"] = slot.state; @@ -1775,6 +1796,9 @@ struct server_context { if (slot.command == SLOT_COMMAND_RELEASE) { slot.state = SLOT_STATE_IDLE; slot.command = SLOT_COMMAND_NONE; + if (n_draft > 0) { + llama_ngram_cache_merge(nc_dynamic, slot.nc_context); + } slot.t_last_used = ggml_time_us(); LOG_INFO("slot released", { @@ -1846,6 +1870,9 @@ struct server_context { llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + for 
(int j = n_keep; j < slot.n_past - n_discard; ++j) { + slot.context_tokens[j] = slot.context_tokens[j + n_discard]; + } if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -1865,7 +1892,7 @@ struct server_context { // start populating the batch for this iteration llama_batch_clear(batch); - // frist, add sampled tokens from any ongoing sequences + // first, add sampled tokens from any ongoing sequences for (auto & slot : slots) { if (slot.state == SLOT_STATE_IDLE) { continue; @@ -1878,6 +1905,12 @@ struct server_context { // TODO: we always have to take into account the "system_tokens" // this is not great and needs to be improved somehow llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true); + slot.context_tokens[system_tokens.size() + slot_npast] = slot.sampled; + if (n_draft > 0) { + llama_ngram_cache_update( + slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + slot.context_tokens.data(), system_tokens.size() + slot_npast, 1, false); + } slot.n_past += 1; @@ -1885,6 +1918,25 @@ struct server_context { slot.cache_tokens.push_back(slot.sampled); } + if (slot.infill || slot.embedding) { + continue; + } + + const int32_t max_draft = std::min(n_draft, slot.n_ctx - slot.n_past - 1); + if (max_draft <= 0) { + continue; + } + + slot.draft.clear(); + slot.draft.push_back(slot.context_tokens[slot.n_past - 1]); + llama_ngram_cache_draft( + slot.context_tokens.data(), slot.n_past, slot.draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, slot.nc_context, nc_dynamic, nc_static); + + for (int j = 1; j < (int)slot.draft.size(); ++j) { + llama_batch_add(batch, slot.draft[j], system_tokens.size() + slot.n_past, {slot.id + 1}, true); + slot.n_past++; + } + LOG_VERBOSE("slot decode token", { {"id_slot", slot.id}, {"id_task", slot.id_task}, @@ -1905,7 +1957,7 @@ struct server_context { for (auto & slot : slots) { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) { - auto & prompt_tokens = slot.prompt_tokens; + std::vector & prompt_tokens = slot.prompt_tokens; // we haven't tokenized the prompt yet - do it now: if (prompt_tokens.empty()) { @@ -2107,6 +2159,11 @@ struct server_context { } llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false); + slot.context_tokens[system_tokens.size() + slot_npast] = prompt_tokens[slot.n_past]; + if (n_draft > 0) { + llama_ngram_cache_update( + slot.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, slot.context_tokens.data(), slot_npast, 1, false); + } if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -2250,42 +2307,56 @@ struct server_context { continue; // continue loop of slots } - completion_token_output result; - const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); + int j = 0; + do { // while (j < std::min(n_batch, (int32_t)slot.draft.size()) && slot.sampled == slot.draft[j]) + completion_token_output result; + const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i + j); - llama_sampling_accept(slot.ctx_sampling, ctx, id, true); + llama_sampling_accept(slot.ctx_sampling, ctx, id, true); - slot.n_decoded += 1; - if (slot.n_decoded == 1) { - slot.t_start_generation = ggml_time_us(); - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } + slot.n_decoded += 1; + 
if (slot.n_decoded == 1) { + slot.t_start_generation = ggml_time_us(); + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } - llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; - result.tok = id; + llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; + result.tok = id; - const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) { - // for llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &cur_p); - } + const int32_t n_probs = slot.sparams.n_probs; + if (slot.sparams.temp <= 0 && n_probs > 0) { + // for llama_sample_token_greedy we need to sort candidates + llama_sample_softmax(ctx, &cur_p); + } - for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) { - result.probs.push_back({ - cur_p.data[i].id, - cur_p.data[i].p - }); - } + for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) { + result.probs.push_back({ + cur_p.data[i].id, + cur_p.data[i].p + }); + } - if (!process_token(result, slot)) { - slot.release(); - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); + ++j; + + if (!process_token(result, slot)) { + slot.n_past -= slot.draft.size() - j; + llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1); + slot.draft.clear(); + slot.release(); + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + } + } while (j < std::min(n_batch, (int32_t)slot.draft.size()) && slot.sampled == slot.draft[j]); + + if (j < (int)slot.draft.size()) { + slot.n_past -= slot.draft.size() - j; + llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1); } slot.i_batch = -1; + slot.draft.clear(); } } @@ -2337,6 +2408,11 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co printf(" - distribute: spread execution evenly over all nodes\n"); printf(" - isolate: only spawn threads on CPUs on the node that execution started on\n"); printf(" - numactl: use the CPU map provided my numactl\n"); + printf(" --draft N max. 
number of additional tokens to draft for n-gram lookup decoding (default: %d)\n", params.n_draft); + printf(" -lcs FNAME, --lookup-cache-static FNAME\n"); + printf(" path to static lookup cache to use for n-gram lookup decoding (not updated by generation)\n"); + printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n"); + printf(" path to dynamic lookup cache to use for n-gram lookup decoding (updated by generation)\n"); if (llama_supports_gpu_offload()) { printf(" -ngl N, --n-gpu-layers N\n"); printf(" number of layers to store in VRAM\n"); @@ -2739,6 +2815,24 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else { invalid_param = true; break; } } + } else if (arg == "-lcs" || arg == "--lookup-cache-static") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lookup_cache_static = argv[i]; + } else if (arg == "-lcd" || arg == "--lookup-cache-dynamic") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.lookup_cache_dynamic = argv[i]; + } else if (arg == "--draft") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.n_draft = std::stoi(argv[i]); } else if (arg == "--embedding" || arg == "--embeddings") { params.embedding = true; } else if (arg == "-cb" || arg == "--cont-batching") { @@ -3013,6 +3107,23 @@ int main(int argc, char ** argv) { LOG_INFO("model loaded", {}); + ctx_server.n_draft = params.n_draft; + + if (!params.lookup_cache_static.empty()) { + LOG_INFO("Loading static lookup cache from %s", {params.lookup_cache_static.c_str()}); + if(!llama_ngram_cache_load(ctx_server.nc_static, params.lookup_cache_static)){ + LOG_ERROR("Did not find a lookup cache under %s", {params.lookup_cache_static.c_str()}); + return 1; + } + } + + if (!params.lookup_cache_dynamic.empty()) { + LOG_INFO("Loading dynamic lookup cache from %s", {params.lookup_cache_dynamic.c_str()}); + if(!llama_ngram_cache_load(ctx_server.nc_dynamic, params.lookup_cache_dynamic)){ + LOG_INFO("Did not find a lookup cache under %s . It will be created on server shutdown.", {params.lookup_cache_dynamic.c_str()}); + } + } + const auto model_meta = ctx_server.model_meta(); // if a custom chat template is not supplied, we will use the one that comes with the model (if any) @@ -3813,6 +3924,11 @@ int main(int argc, char ** argv) { svr->stop(); t.join(); + if (!params.lookup_cache_dynamic.empty()) { + LOG_INFO("Saving dynamic lookup cache to %s", {params.lookup_cache_dynamic.c_str()}); + llama_ngram_cache_save(ctx_server.nc_dynamic, params.lookup_cache_dynamic); + } + llama_backend_free(); return 0; diff --git a/examples/server/tests/features/results.feature b/examples/server/tests/features/results.feature index aa0b8d0c648b4c..f1bbcba3e28e77 100644 --- a/examples/server/tests/features/results.feature +++ b/examples/server/tests/features/results.feature @@ -51,7 +51,6 @@ Feature: Results Scenario Outline: consistent results with same seed and varying batch size Given 4 slots And temperature - # And 0 as draft Then the server is starting Then the server is healthy @@ -79,3 +78,32 @@ Feature: Results # and https://github.com/ggerganov/llama.cpp/pull/6122#discussion_r1531405574 . # | 2 | 1.0 | # | 4 | 1.0 | + + Scenario Outline: consistent results with same seed and varying n_draft + Given 0.0 temperature + Given slots + Given 0 as draft + Then the server is starting + Then the server is healthy + + Given 4 prompts "Write a very long story about AI." 
with seed 42 + And concurrent completion requests + Then the server is busy + Then the server is idle + And all slots are idle + + Given 3 as draft + Then the server is starting + Then the server is healthy + + Given 4 prompts "Write a very long story about AI." with seed 42 + And concurrent completion requests + Then the server is busy + Then the server is idle + And all slots are idle + + Then all predictions are equal + Examples: + | n_slots | + | 1 | + | 2 | diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index b8dbef21d1b768..3a382c93bc2fa0 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -3,6 +3,7 @@ import json import os import re +import signal import socket import subprocess import sys @@ -845,6 +846,8 @@ async def request_completion(prompt, headers=headers, timeout=3600) as response: if expect_api_error is None or not expect_api_error: + if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON' and response.status != 200: + print(f"Unexpected bad HTTP response: {response.status}") assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin return await response.json() @@ -1232,8 +1235,7 @@ def start_server_background(context): server_args.extend(['--ubatch-size', context.n_ubatch]) if context.n_gpu_layer: server_args.extend(['--n-gpu-layers', context.n_gpu_layer]) - if context.draft is not None: - server_args.extend(['--draft', context.draft]) + server_args.extend(['--draft', context.draft if context.draft is not None else 0]) if context.server_continuous_batching: server_args.append('--cont-batching') if context.server_embeddings: @@ -1275,6 +1277,14 @@ def start_server_background(context): 'stdout': subprocess.PIPE, 'stderr': subprocess.PIPE } + + if context.server_process is not None: + if os.name == 'nt': + interrupt = signal.CTRL_C_EVENT + else: + interrupt = signal.SIGINT + context.server_process.send_signal(interrupt) + context.server_process = subprocess.Popen( [str(arg) for arg in [context.server_path, *server_args]], **pkwargs)