From d924e612b2e0ab6d8571c023067a2df33b5f7889 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Mon, 22 Apr 2024 23:49:49 +0200
Subject: [PATCH] sampling: separate rng per sampling context

---
 common/sampling.cpp              | 12 +++++++++++-
 common/sampling.h                |  6 ++++++
 examples/lookup/lookup-stats.cpp |  1 -
 examples/lookup/lookup.cpp       |  2 +-
 examples/main/main.cpp           |  2 +-
 examples/server/server.cpp       |  2 +-
 llama.cpp                        |  9 ++++++---
 llama.h                          |  8 +++++++-
 8 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 45d68b26c2b93f..e78c7706c4b602 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,4 +1,5 @@
 #include "sampling.h"
+#include <random>
 
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +34,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
+
     return result;
 }
 
@@ -62,6 +65,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }
 
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+    ctx->rng.seed(seed);
+}
+
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
@@ -203,7 +213,7 @@ static llama_token llama_sampling_sample_impl(
 
         sampler_queue(ctx_main, params, cur_p, min_keep);
 
-        id = llama_sample_token(ctx_main, &cur_p);
+        id = llama_sample_token_with_rng(ctx_main, &cur_p, &ctx_sampling->rng);
 
         //{
         //    const int n_top = 10;
diff --git a/common/sampling.h b/common/sampling.h
index 639b819ab4fb2c..70794ea10da8c1 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -4,6 +4,7 @@
 
 #include "grammar-parser.h"
 
+#include <random>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -79,6 +80,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
+
+    std::mt19937 rng;
 };
 
 #include "common.h"
@@ -93,6 +96,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 // - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);
 
+// Set the sampler seed
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
 
diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp
index 41b62c2fe9f76b..87ecc0a4f1394e 100644
--- a/examples/lookup/lookup-stats.cpp
+++ b/examples/lookup/lookup-stats.cpp
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index 9526e898fe7638..aceeeee5082901 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
@@ -108,6 +107,7 @@ int main(int argc, char ** argv){
     bool has_eos = false;
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    llama_sampling_set_rng_seed(ctx_sampling, params.seed);
 
     std::vector<llama_token> draft;
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 1180734b9760d2..ca15dba179bb6b 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
 
             LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
         }
     }
@@ -521,6 +520,7 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    llama_sampling_set_rng_seed(ctx_sampling, params.seed);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 25bc2963967725..a735d41a6dde4c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1028,7 +1028,7 @@ struct server_context {
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
-            llama_set_rng_seed(ctx, slot.params.seed);
+            llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
         }
 
         slot.command = SLOT_COMMAND_LOAD_PROMPT;
diff --git a/llama.cpp b/llama.cpp
index a25d115c1d82af..ee1c9280528286 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13478,7 +13478,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
     return result;
 }
 
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, void * rng) {
     GGML_ASSERT(ctx);
 
     const int64_t t_start_sample_us = ggml_time_us();
@@ -13491,8 +13491,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     }
 
     std::discrete_distribution<> dist(probs.begin(), probs.end());
-    auto & rng = ctx->rng;
-    int idx = dist(rng);
+    int idx = dist(*((std::mt19937 *) rng));
 
     llama_token result = candidates->data[idx].id;
 
@@ -13501,6 +13500,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }
 
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng(ctx, candidates, &ctx->rng);
+}
+
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
diff --git a/llama.h b/llama.h
index 4effca42cc65de..8f29cf752b70b9 100644
--- a/llama.h
+++ b/llama.h
@@ -987,7 +987,13 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities.
+    /// @details Randomly selects a token from the candidates based on their probabilities using a given pointer to a std::mt19937.
+    LLAMA_API llama_token llama_sample_token_with_rng(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+                            void * rng);
+
+    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
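
Usage sketch: the snippet below shows how the new per-context RNG API introduced by this patch fits together, assuming a llama.cpp tree with the patch applied; the variable `sparams` and the literal seed value are illustrative only, not part of the patch.

    #include "sampling.h"

    // Each llama_sampling_context now owns a std::mt19937. llama_sampling_init()
    // seeds it with LLAMA_DEFAULT_SEED (i.e. from time(NULL)), and it can be
    // reseeded per context without touching any other context or ctx->rng:
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
    llama_sampling_set_rng_seed(ctx_sampling, 1234); // deterministic draws for this context only

    // Token selection in llama_sampling_sample() now goes through
    //     llama_sample_token_with_rng(ctx_main, &cur_p, &ctx_sampling->rng);
    // while the old llama_sample_token() keeps its previous behavior by
    // delegating to the same entry point with the context-global &ctx->rng.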