Commit d924e61
sampling: separate rng per sampling context
JohannesGaessler committed Apr 22, 2024
1 parent b1a1891 commit d924e61
Showing 8 changed files with 33 additions and 9 deletions.
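
In short, the sampling RNG moves out of the shared llama_context and into each llama_sampling_context, which now owns its own std::mt19937 seeded through the new llama_sampling_set_rng_seed(). A minimal caller-side sketch of the resulting pattern (llama_sampling_init and llama_sampling_sample are the pre-existing helpers from common/sampling.h; the exact surrounding code varies per example):

    // Each sampling context carries an independent RNG.
    // Passing LLAMA_DEFAULT_SEED makes it fall back to time(NULL).
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
    llama_sampling_set_rng_seed(ctx_sampling, params.seed);

    // Token selection now draws from ctx_sampling->rng rather than the RNG inside llama_context.
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);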
12 changes: 11 additions & 1 deletion common/sampling.cpp
@@ -1,4 +1,5 @@
 #include "sampling.h"
+#include <random>
 
 struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
@@ -33,6 +34,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    llama_sampling_set_rng_seed(result, LLAMA_DEFAULT_SEED);
+
     return result;
 }
 
@@ -62,6 +65,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     ctx->cur.clear();
 }
 
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        seed = time(NULL);
+    }
+    ctx->rng.seed(seed);
+}
+
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
     if (dst->grammar) {
         llama_grammar_free(dst->grammar);
@@ -203,7 +213,7 @@ static llama_token llama_sampling_sample_impl(
 
     sampler_queue(ctx_main, params, cur_p, min_keep);
 
-    id = llama_sample_token(ctx_main, &cur_p);
+    id = llama_sample_token_with_rng(ctx_main, &cur_p, &ctx_sampling->rng);
 
     //{
    //    const int n_top = 10;

6 changes: 6 additions & 0 deletions common/sampling.h
@@ -4,6 +4,7 @@
 
 #include "grammar-parser.h"
 
+#include <random>
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -79,6 +80,8 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
+
+    std::mt19937 rng;
 };
 
 #include "common.h"
@@ -93,6 +96,9 @@ void llama_sampling_free(struct llama_sampling_context * ctx);
 // - reset grammar
 void llama_sampling_reset(llama_sampling_context * ctx);
 
+// Set the sampler seed
+void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed);
+
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
 
1 change: 0 additions & 1 deletion examples/lookup/lookup-stats.cpp
@@ -30,7 +30,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt

2 changes: 1 addition & 1 deletion examples/lookup/lookup.cpp
@@ -38,7 +38,6 @@ int main(int argc, char ** argv){
 
     // load the model
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    llama_set_rng_seed(ctx, params.seed);
     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
 
     // tokenize the prompt
@@ -108,6 +107,7 @@ int main(int argc, char ** argv){
     bool has_eos = false;
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    llama_sampling_set_rng_seed(ctx_sampling, params.seed);
 
     std::vector<llama_token> draft;
 
2 changes: 1 addition & 1 deletion examples/main/main.cpp
@@ -240,7 +240,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             session_tokens.resize(n_token_count_out);
-            llama_set_rng_seed(ctx, params.seed);
             LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int)session_tokens.size());
         }
     }
@@ -521,6 +520,7 @@ int main(int argc, char ** argv) {
     }
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    llama_sampling_set_rng_seed(ctx_sampling, params.seed);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
        // predict

2 changes: 1 addition & 1 deletion examples/server/server.cpp
@@ -1028,7 +1028,7 @@ struct server_context {
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
-            llama_set_rng_seed(ctx, slot.params.seed);
+            llama_sampling_set_rng_seed(slot.ctx_sampling, slot.params.seed);
         }
 
         slot.command = SLOT_COMMAND_LOAD_PROMPT;

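For the server, this means each slot's sampler is seeded from its own request parameters instead of reseeding the single RNG inside the shared llama_context. A rough sketch of the consequence, with hypothetical slot_a/slot_b names standing in for two server slots:

    // Two slots, two seeds, two independent random streams.
    llama_sampling_set_rng_seed(slot_a.ctx_sampling, 1234);
    llama_sampling_set_rng_seed(slot_b.ctx_sampling, 5678);
    // Sampling tokens for slot A no longer changes the sequence of draws slot B will see.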
9 changes: 6 additions & 3 deletions llama.cpp
@@ -13478,7 +13478,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
     return result;
 }
 
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, void * rng) {
     GGML_ASSERT(ctx);
 
     const int64_t t_start_sample_us = ggml_time_us();
@@ -13491,8 +13491,7 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     }
 
     std::discrete_distribution<> dist(probs.begin(), probs.end());
-    auto & rng = ctx->rng;
-    int idx = dist(rng);
+    int idx = dist(*((std::mt19937 *) rng));
 
     llama_token result = candidates->data[idx].id;
 
@@ -13501,6 +13500,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
     return result;
 }
 
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+    return llama_sample_token_with_rng(ctx, candidates, &ctx->rng);
+}
+
 void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
     const int64_t t_start_sample_us = ggml_time_us();
 
8 changes: 7 additions & 1 deletion llama.h
@@ -987,7 +987,13 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates);
 
-    /// @details Randomly selects a token from the candidates based on their probabilities.
+    /// @details Randomly selects a token from the candidates based on their probabilities using a given pointer to a std::mt19937.
+    LLAMA_API llama_token llama_sample_token_with_rng(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+                            void * rng);
+
+    /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
     LLAMA_API llama_token llama_sample_token(
             struct llama_context * ctx,
           llama_token_data_array * candidates);
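Since llama.h is a C header, the new entry point takes the RNG as an opaque void * rather than a std::mt19937 reference; C++ callers pass the address of an engine they own, as common/sampling.cpp now does with &ctx_sampling->rng. A minimal sketch of driving the public API with a caller-owned RNG (ctx and the candidate filling are assumed to be set up as in the existing examples):

    #include <random>
    #include <vector>
    #include "llama.h"

    std::vector<llama_token_data> candidates_vec;   // filled from the logits, as in common/sampling.cpp
    llama_token_data_array cur_p = { candidates_vec.data(), candidates_vec.size(), false };

    std::mt19937 rng(42);                           // caller-owned engine, e.g. one per sequence
    const llama_token id = llama_sample_token_with_rng(ctx, &cur_p, &rng);

    // llama_sample_token(ctx, &cur_p) keeps the old behavior: it forwards to
    // llama_sample_token_with_rng using the RNG stored inside llama_context.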
