Commit 8dcd771
lookup: evaluation tools, use corpus/previous gens

1 parent: f9c7ba3

Showing 10 changed files with 396 additions and 61 deletions.
@@ -0,0 +1,43 @@
#include "ggml.h"
#include "llama.h"
#include "common.h"
#include "ngram-cache.h"

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main(int argc, char ** argv){
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }
    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model = NULL;
    llama_context * ctx = NULL;

    // load the model
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    GGML_ASSERT(model != nullptr);

    // tokenize the prompt
    const bool add_bos = llama_should_add_bos_token(model);

    std::vector<llama_token> inp;
    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
    fprintf(stderr, "%s: tokenization done\n", __func__);


    llama_ngram_cache ngram_cache;
    llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
    fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());

    llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);
}
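For intuition, the tool above boils down to n-gram counting: roughly speaking, llama_ngram_cache_update walks the tokenized corpus and tallies which token follows each n-gram, and llama_ngram_cache_save writes those counts to disk. Below is a minimal, self-contained sketch of that counting step only; the token, ngram_key, ngram_counts, and count_ngrams names are illustrative and are not part of the llama.cpp API, and the data layout is deliberately simpler than the real cache.

// Illustrative sketch only -- not the llama.cpp implementation.
// Counts, for every n-gram in a token stream, how often each token follows it.
#include <cstdint>
#include <cstdio>
#include <map>
#include <unordered_map>
#include <vector>

using token     = int32_t;
using ngram_key = std::vector<token>;              // the n preceding tokens

struct ngram_key_hash {
    size_t operator()(const ngram_key & k) const {
        size_t h = 0;
        for (token t : k) {
            h = h*1000003u + (size_t) t;           // simple polynomial hash
        }
        return h;
    }
};

// ngram -> (next token -> count)
using ngram_counts = std::unordered_map<ngram_key, std::map<token, int32_t>, ngram_key_hash>;

static void count_ngrams(ngram_counts & cache, const std::vector<token> & inp, int n) {
    for (size_t i = n; i < inp.size(); ++i) {
        ngram_key key(inp.begin() + i - n, inp.begin() + i);
        cache[key][inp[i]]++;
    }
}

int main() {
    const std::vector<token> corpus = {1, 2, 3, 1, 2, 4, 1, 2, 3};
    ngram_counts cache;
    count_ngrams(cache, corpus, 2);

    // After the n-gram {1, 2}, token 3 was seen twice and token 4 once.
    const ngram_key key12 = {1, 2};
    printf("count({1,2} -> 3) = %d\n", cache[key12][3]);   // 2
    printf("count({1,2} -> 4) = %d\n", cache[key12][4]);   // 1
    return 0;
}

Compiled standalone (for example with g++ -std=c++17), the sketch prints the two counts shown in the comments.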
@@ -0,0 +1,47 @@
#include "ggml.h"
#include "llama.h"
#include "common.h"
#include "ngram-cache.h"

#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

static void print_usage() {
    fprintf(stderr, "Merges multiple lookup cache files into a single one.\n");
    fprintf(stderr, "Usage: lookup-merge [--help] lookup_part_1.bin lookup_part_2.bin ... lookup_merged.bin\n");
}

int main(int argc, char ** argv){
    if (argc < 3) {
        print_usage();
        exit(1);
    }

    std::vector<std::string> args;
    args.resize(argc-1);
    for (int i = 0; i < argc-1; ++i) {
        args[i] = argv[i+1];
        if (args[i] == "-h" || args[i] == "--help") {
            print_usage();
            exit(0);
        }
    }

    fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
    llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);

    for (size_t i = 1; i < args.size()-1; ++i) {
        fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
        llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);

        llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
    }

    fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
    llama_ngram_cache_save(ngram_cache_merged, args.back());
}
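Merging, as used by the loop above, is conceptually just adding count tables together: for every n-gram in the source cache, its per-token counts are added into the destination. A minimal, self-contained sketch of that idea follows; the counts_t type and merge_counts helper are illustrative stand-ins, not the llama_ngram_cache type or llama_ngram_cache_merge itself.

// Illustrative sketch only -- not the llama.cpp implementation.
// Merging two n-gram caches = summing the per-token counts for each n-gram.
#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

// simplified cache: ngram (as a string key) -> (next token -> count)
using counts_t = std::map<std::string, std::map<int32_t, int32_t>>;

static void merge_counts(counts_t & dst, const counts_t & src) {
    for (const auto & [ngram, token_counts] : src) {
        for (const auto & [tok, cnt] : token_counts) {
            dst[ngram][tok] += cnt;               // missing entries start at 0
        }
    }
}

int main() {
    counts_t a = {{"1,2", {{3, 2}}}};
    counts_t b = {{"1,2", {{3, 1}, {4, 5}}}, {"7,8", {{9, 1}}}};
    merge_counts(a, b);

    printf("a[\"1,2\"][3] = %d\n", a["1,2"][3]);  // 3
    printf("a[\"1,2\"][4] = %d\n", a["1,2"][4]);  // 5
    printf("a[\"7,8\"][9] = %d\n", a["7,8"][9]);  // 1
    return 0;
}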
@@ -0,0 +1,163 @@
#include "ggml.h"
#include "common.h"
#include "llama.h"
#include "log.h"
#include "ngram-cache.h"

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
#include <unordered_map>

int main(int argc, char ** argv){
    gpt_params params;

    if (!gpt_params_parse(argc, argv, params)) {
        return 1;
    }

    const int n_draft = params.n_draft;

    // init llama.cpp
    llama_backend_init();
    llama_numa_init(params.numa);

    llama_model * model = NULL;
    llama_context * ctx = NULL;

    // load the model
    std::tie(model, ctx) = llama_init_from_gpt_params(params);
    llama_set_rng_seed(ctx, params.seed);
    GGML_ASSERT(llama_n_vocab(model) < (1 << 16));

    // tokenize the prompt
    const bool add_bos = llama_should_add_bos_token(model);
    LOG("add_bos tgt: %d\n", add_bos);

    std::vector<llama_token> inp;
    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);

    llama_ngram_cache ngram_cache_context;
    llama_ngram_cache ngram_cache_dynamic;
    llama_ngram_cache ngram_cache_static;
    int64_t t_draft_flat_us = 0;
    int64_t t_draft_us = 0;

    {
        const int64_t t_start_draft_us = ggml_time_us();

        if (!params.lookup_cache_static.empty()) {
            try {
                ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
            } catch (std::system_error const &) {
                fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
                exit(1);
            }
        }

        if (!params.lookup_cache_dynamic.empty()) {
            try {
                ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
            } catch (std::system_error const &) {} // if the file does not exist it will simply be created at the end of the program
        }

        t_draft_flat_us += ggml_time_us() - t_start_draft_us;
    }

    const int n_input = inp.size();
    const int n_ctx = params.n_ctx;

    int n_drafted = 0;
    int n_accept = 0;

    const int64_t t_start_ms = ggml_time_ms();

    // Iterate over input tokens in chunks of size n_ctx.
    // Each chunk is treated as if a sequential generation but with pre-determined tokens to ensure reproducibility.
    for (int i_start = 0; i_start + n_ctx < n_input; i_start += n_ctx) {
        const std::vector<llama_token> inp_slice(inp.begin() + i_start, inp.begin() + i_start + n_ctx);
        std::vector<llama_token> pseudo_output;
        pseudo_output.push_back(inp_slice[0]);

        while ((int) pseudo_output.size() < n_ctx) {
            // Simulate drafting and decoding from draft:
            std::vector<llama_token> draft;
            draft.push_back(pseudo_output.back());

            {
                const int64_t t_start_draft_us = ggml_time_us();
                llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
                t_draft_us += ggml_time_us() - t_start_draft_us;
            }

            n_drafted += draft.size() - 1;

            for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
                const llama_token ground_truth = inp_slice[pseudo_output.size()];
                const llama_token drafted = draft[j];

                if (ground_truth != drafted) {
                    break;
                }

                ++n_accept;
                pseudo_output.push_back(ground_truth);

                {
                    const int64_t t_start_draft_us = ggml_time_us();
                    llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
                    t_draft_us += ggml_time_us() - t_start_draft_us;
                }
            }

            // After each simulated batch decoding simulate the sampling of a single token:
            if ((int) pseudo_output.size() < n_ctx) {
                pseudo_output.push_back(inp_slice[pseudo_output.size()]);
                {
                    const int64_t t_start_draft_us = ggml_time_us();
                    llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
                    t_draft_us += ggml_time_us() - t_start_draft_us;
                }
            }

            draft.erase(draft.begin());
        }

        if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
            const int64_t t_now_ms = ggml_time_ms();
            const int64_t eta_ms = (n_input - i_start) * (t_now_ms - t_start_ms) / i_start;
            const int64_t eta_min = eta_ms / (60*1000);
            const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;

            LOG_TEE("%d/%d done, ETA: %02ld:%02ld\n", i_start, n_input, eta_min, eta_s);
        }

        // After each chunk, update the dynamic ngram cache with the context ngram cache:
        llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
        ngram_cache_context.clear();
    }

    LOG_TEE("\n");

    LOG_TEE("\n");
    LOG_TEE("n_draft = %d\n", n_draft);
    LOG_TEE("n_predict = %d\n", n_input - n_input % n_ctx);
    LOG_TEE("n_drafted = %d\n", n_drafted);
    LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
    LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
            t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
    LOG_TEE("n_accept = %d\n", n_accept);
    LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    fprintf(stderr, "\n\n");

    return 0;
}
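The acceptance statistics reported above are determined by the drafting step: the caches are queried for a likely continuation of the most recent tokens, the continuation is appended to the draft, and drafted tokens only count as accepted while they match the corpus. Below is a minimal, self-contained sketch of a greedy n-gram drafting loop that illustrates the core idea; counts_t and draft_greedy are illustrative names, and the sketch is deliberately simpler than llama_ngram_cache_draft, which combines the context, dynamic, and static caches.

// Illustrative sketch only -- not the llama.cpp implementation.
// Greedy n-gram drafting: repeatedly look up the most frequent continuation
// of the last n tokens and append it to the draft.
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

using token    = int32_t;
using ngram    = std::vector<token>;
using counts_t = std::map<ngram, std::map<token, int32_t>>;  // ngram -> (next token -> count)

static std::vector<token> draft_greedy(const counts_t & cache, std::vector<token> context, int n, int n_draft) {
    std::vector<token> draft;
    while ((int) draft.size() < n_draft && (int) context.size() >= n) {
        const ngram key(context.end() - n, context.end());
        const auto it = cache.find(key);
        if (it == cache.end()) {
            break;                                // no continuation known for this n-gram
        }
        token best = -1;
        int32_t best_count = 0;
        for (const auto & [tok, cnt] : it->second) {
            if (cnt > best_count) {
                best = tok;
                best_count = cnt;
            }
        }
        draft.push_back(best);
        context.push_back(best);                  // extend the context with the drafted token
    }
    return draft;
}

int main() {
    // toy cache: "1 2" is usually followed by 3, and "2 3" by 1
    counts_t cache = {
        {{1, 2}, {{3, 2}, {4, 1}}},
        {{2, 3}, {{1, 1}}},
    };
    const std::vector<token> context = {5, 1, 2};
    const std::vector<token> draft = draft_greedy(cache, context, 2, 8);
    for (token t : draft) {
        printf("%d ", t);                         // prints: 3 1
    }
    printf("\n");
    return 0;
}

With this toy cache, the context {5, 1, 2} drafts 3 (the most frequent successor of {1, 2}) and then 1, after which no matching n-gram remains and drafting stops.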