fix save-load-state example
ngxson committed Oct 12, 2024
1 parent 7264596 commit 6395174
Showing 1 changed file with 25 additions and 5 deletions.
examples/save-load-state/save-load-state.cpp (25 additions, 5 deletions)
@@ -48,9 +48,16 @@ int main(int argc, char ** argv) {
     // tokenize prompt
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
+    // prepare the batch
+    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+    for (size_t i = 0; i < tokens.size(); i++) {
+        common_batch_add(batch, tokens[i], i, {0}, false);
+    }
+    batch.logits[batch.n_tokens - 1] = true; // generate next token
+
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
-    n_past += tokens.size();
+    llama_decode(ctx, batch);
+    n_past += batch.n_tokens;
 
     // save state (rng, logits, embedding and kv_cache) to file
     {
@@ -77,8 +84,12 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) {
+        common_batch_clear(batch);
+        common_batch_add(batch, next_token, n_past, {0}, true);
+
+        if (llama_decode(ctx, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_batch_free(batch);
             llama_free(ctx);
             llama_free_model(model);
             return 1;
@@ -133,8 +144,12 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
-        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) {
+        common_batch_clear(batch);
+        common_batch_add(batch, next_token, n_past, {0}, true);
+
+        if (llama_decode(ctx2, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_batch_free(batch);
             llama_free(ctx2);
             llama_free_model(model);
             return 1;
@@ -221,8 +236,12 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 
-        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
+        common_batch_clear(batch);
+        common_batch_add(batch, next_token, n_past, {1}, true);
+
+        if (llama_decode(ctx3, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_batch_free(batch);
             llama_free(ctx3);
             llama_free_model(model);
             return 1;
@@ -236,6 +255,7 @@ int main(int argc, char ** argv) {
     llama_sampler_free(smpl2);
     llama_sampler_free(smpl3);
 
+    llama_batch_free(batch);
     llama_free(ctx3);
     llama_free_model(model);
 
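What the diff replaces: llama_batch_get_one() builds a bare batch whose positions and sequence ids are left implicit, so the example could not address sequence 1 when exercising the restored third context. The fix builds the batch explicitly with llama_batch_init(), pins every token to a position and sequence id via common_batch_add(), requests logits only where sampling continues, and frees the batch on every exit path. Below is a minimal sketch of that lifecycle, assuming the llama.cpp common helpers used in the diff; decode_prompt and decode_one are hypothetical wrappers for illustration, not functions in the example.

    #include "common.h"
    #include "llama.h"

    #include <vector>

    // Hypothetical wrapper: feed a whole prompt through one llama_decode() call.
    // Every token gets an explicit position in sequence `seq`; only the last
    // token requests logits, since that is where sampling continues.
    static bool decode_prompt(llama_context * ctx, llama_batch & batch,
                              const std::vector<llama_token> & tokens,
                              llama_seq_id seq, int & n_past) {
        common_batch_clear(batch);
        for (size_t i = 0; i < tokens.size(); i++) {
            common_batch_add(batch, tokens[i], n_past + (llama_pos) i, {seq}, false);
        }
        batch.logits[batch.n_tokens - 1] = true; // generate next token
        if (llama_decode(ctx, batch) != 0) {
            return false;
        }
        n_past += batch.n_tokens;
        return true;
    }

    // Hypothetical wrapper: one generation step, reusing the same batch.
    static bool decode_one(llama_context * ctx, llama_batch & batch,
                           llama_token tok, llama_seq_id seq, int & n_past) {
        common_batch_clear(batch);
        common_batch_add(batch, tok, n_past, {seq}, true);
        if (llama_decode(ctx, batch) != 0) {
            return false;
        }
        n_past += 1;
        return true;
    }

The batch is sized once with llama_batch_init(tokens.size(), 0, 1) and must be released with llama_batch_free(), which is why the diff adds a free before every early return. The third context decodes with seq id {1}, matching the sequence the example copies the saved state into.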