From 6395174a5410ba6634b8256e6cbc4c4625e46907 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 12 Oct 2024 23:09:05 +0200 Subject: [PATCH] fix save-load-state example --- examples/save-load-state/save-load-state.cpp | 30 ++++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index f72832d9cc672..5f60a86cbc2d2 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -48,9 +48,16 @@ int main(int argc, char ** argv) { // tokenize prompt auto tokens = common_tokenize(ctx, params.prompt, true); + // prepare the batch + llama_batch batch = llama_batch_init(tokens.size(), 0, 1); + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], i, {0}, false); + } + batch.logits[batch.n_tokens - 1] = true; // generate next token + // evaluate prompt - llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size())); - n_past += tokens.size(); + llama_decode(ctx, batch); + n_past += batch.n_tokens; // save state (rng, logits, embedding and kv_cache) to file { @@ -77,8 +84,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); llama_free(ctx); llama_free_model(model); return 1; @@ -133,8 +144,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx2, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); llama_free(ctx2); llama_free_model(model); return 1; @@ -221,8 +236,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) { + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {1}, true); + + if (llama_decode(ctx3, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); llama_free(ctx3); llama_free_model(model); return 1; @@ -236,6 +255,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl2); llama_sampler_free(smpl3); + llama_batch_free(batch); llama_free(ctx3); llama_free_model(model);