Check the full vocab for grammar only if necessary #4306

Merged · 8 commits · Dec 23, 2023
Changes from 3 commits
38 changes: 36 additions & 2 deletions common/sampling.cpp
@@ -103,7 +103,8 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int idx,
bool is_resampling) { // Add a parameter to indicate if we are resampling
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -128,8 +129,17 @@ llama_token llama_sampling_sample(

llama_token id = 0;

// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);

// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;

if (!is_resampling) {
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
@@ -165,7 +175,8 @@ llama_token llama_sampling_sample(
}
}

if (ctx_sampling->grammar != NULL) {
// If we are in the resampling phase, apply grammar checks before sampling logic
if (is_resampling && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

@@ -212,6 +223,29 @@ llama_token llama_sampling_sample(
}
}

if (ctx_sampling->grammar != NULL && !is_resampling) {
// Create an array with a single token data element for the sampled id
llama_token_data single_token_data = {id, logits[id], 0.0f};
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

// Apply grammar constraints to the single token
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);

// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
Collaborator commented on lines +268 to +275:
Yeah we could probably expose something more straightforward to check this if we want. It would probably be like llama_grammar_accept_token but returning a bool instead of throwing.
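
A minimal sketch of what such a helper could look like, reusing the single-token check from this diff (the name llama_grammar_token_is_valid is hypothetical and not part of this PR):

// Hypothetical helper: reports whether `id` is allowed by the grammar,
// returning a bool rather than throwing like llama_grammar_accept_token.
static bool llama_grammar_token_is_valid(
        struct llama_context * ctx_main,
        struct llama_grammar * grammar,
        llama_token            id,
        float                  logit) {
    // One-element candidate array holding just the sampled token.
    llama_token_data       single_token_data       = { id, logit, 0.0f };
    llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

    // llama_sample_grammar sets the logit of disallowed tokens to -INFINITY.
    llama_sample_grammar(ctx_main, &single_token_data_array, grammar);

    return single_token_data_array.data[0].logit != -INFINITY;
}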


// If the token is not valid according to the grammar, perform resampling
if (!is_valid) {
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());

// Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits);

// Recursively call llama_sampling_sample to resample with the grammar checks applied first
return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
}
}

return id;
}

9 changes: 5 additions & 4 deletions common/sampling.h
@@ -98,10 +98,11 @@ std::string llama_sampling_print(const llama_sampling_params & params);
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx,
bool is_resampling = false); // Add the new parameter with default value
Owner commented:

Should hide the is_resampling argument behind the public API - it's an implementation detail
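
One way to do that, sketched under the assumption of an internal helper (the name llama_sampling_sample_impl is hypothetical here): keep the public declaration as it was and forward to a static function that carries the flag.

// common/sampling.h — public API, no is_resampling parameter:
llama_token llama_sampling_sample(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx = 0);

// common/sampling.cpp — the internal implementation carries the flag:
static llama_token llama_sampling_sample_impl(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx,
        bool is_resampling);

llama_token llama_sampling_sample(
        struct llama_sampling_context * ctx_sampling,
        struct llama_context * ctx_main,
        struct llama_context * ctx_cfg,
        int idx) {
    // The first pass is never a resample; the recursion inside the impl
    // passes true for the grammar-constrained second pass.
    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling */ false);
}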

Contributor Author replied:

The comment here is also redundant, technically.
Also, as I mentioned in another comment, I think the candidates array isn't getting properly resized for the '2nd pass' of sampling when Top-K, Min-P, etc. truncation samplers are used, because I only restore the scores rather than the size. How might I do that?

Owner replied:

Hm, not sure what you mean. The candidates array is part of the sampling context and it is cleared and populated on each call. The proposed code should work - does it not work?

Contributor Author (@kalomaze) replied, Dec 6, 2023:

> Hm, not sure what you mean. The candidates array is part of the sampling context and it is cleared and populated on each call. The proposed code should work - does it not work?

It does seem to work, but my concern was that I wasn't sure whether recursively calling the function would accomplish that when resampling the same token. If my implementation looks correct (it works in my brief testing), then this is probably mergeable.
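
For reference, the repopulation the owner describes happens near the top of llama_sampling_sample, so each recursive call starts from a fresh, full-vocabulary candidate list (paraphrased from common/sampling.cpp; the exact code may differ):

// Rebuilt on every call, including the recursive resampling pass:
cur.clear();

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
    cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}

llama_token_data_array cur_p = { cur.data(), cur.size(), false };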


void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
2 changes: 1 addition & 1 deletion examples/infill/infill.cpp
@@ -527,7 +527,7 @@ int main(int argc, char ** argv) {

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {

const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);

llama_sampling_accept(ctx_sampling, ctx, id, true);

2 changes: 1 addition & 1 deletion examples/main/main.cpp
@@ -630,7 +630,7 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}

const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance, 0, false);

llama_sampling_accept(ctx_sampling, ctx, id, true);
