dynamic lookup cache, code refactor
JohannesGaessler committed Mar 10, 2024
1 parent 9ee3d63 commit a908794
Showing 6 changed files with 240 additions and 174 deletions.
common/common.cpp: 260 changes (157 additions, 103 deletions)
@@ -4,6 +4,7 @@
 #include <algorithm>
 #include <cassert>
 #include <cmath>
+#include <cstdint>
 #include <cstring>
 #include <ctime>
 #include <fstream>
@@ -12,6 +13,7 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <system_error>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -720,6 +722,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.lookup_cache_static = argv[i];
+        } else if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lookup_cache_dynamic = argv[i];
         } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1100,7 +1108,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("  -lcs FNAME, --lookup-cache-static FNAME\n");
-    printf("                        path to static lookup cache to use for lookup decoding\n");
+    printf("                        path to static lookup cache to use for lookup decoding (not updated by generation)\n");
+    printf("  -lcd FNAME, --lookup-cache-dynamic FNAME\n");
+    printf("                        path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
@@ -1860,15 +1870,12 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
     printf("\n=== Done dumping\n");
 }
 
-void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
                               std::vector<llama_token> & inp, int nnew, bool print_progress) {
     const int64_t t_start_ms = ggml_time_ms();
-    const int ngram_max = ngram_min + ncs.size()-1;
     const int inp_size = inp.size();
 
     for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
-        llama_ngram_cache & nc = ncs[ngram_size - ngram_min];
-
         const int i_start = std::max(inp_size - nnew, ngram_size);
         for (int i = i_start; i < inp_size; ++i) {
             const int ngram_start = i - ngram_size;
@@ -1880,11 +1887,11 @@ void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_mi
             }
             const llama_token token = inp[i];
 
-            llama_ngram_cache::iterator part_it = nc.find(ngram);
-            if (part_it == nc.end()) {
+            llama_ngram_cache::iterator part_it = ngram_cache.find(ngram);
+            if (part_it == ngram_cache.end()) {
                 llama_ngram_cache_part part;
                 part.emplace(token, 1);
-                nc.emplace(ngram, part);
+                ngram_cache.emplace(ngram, part);
             } else {
                 llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
                 if (token_count_it == part_it->second.end()) {
@@ -1911,128 +1918,150 @@ static llama_token get_token(const std::vector<llama_token> & inp, const std::ve
 };
 
 // If sample size or percentage in context are below these thresholds the draft is aborted early:
-constexpr int    draft_min_sample_size[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
-constexpr int    draft_min_percent[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+constexpr int    draft_min_sample_size_t1[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
+constexpr int    draft_min_percent_t1[LLAMA_NGRAM_MAX]     = {66, 50, 50, 50};
+constexpr int    draft_min_sample_size_t2[LLAMA_NGRAM_MAX] = { 4,  3,  2,  2};
+constexpr int    draft_min_percent_t2[LLAMA_NGRAM_MAX]     = {75, 66, 66, 66};
 
-void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
-    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2
-) {
-    const int inp_size = inp.size();
-    const int ngram_max = ngram_min + ncs_t1.size()-1;
+static llama_token try_draft(llama_ngram_cache & nc_primary, const uint64_t ngram_primary) {
+    llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
+    if (part_primary_it == nc_primary.end()) {
+        return -1;
+    }
+    const llama_ngram_cache_part part_primary = part_primary_it->second;
 
-    while ((int) draft.size()-1 < n_draft) {
-        bool draft_success = false;
+    int max_count_primary = 0;
+    int sum_count_primary = 0;
+    llama_token max_token = -1;
 
-        const int ngram_start_t2 = inp_size-2 + draft.size()-1;
-        uint64_t ngram_t2 = get_token(inp, draft, ngram_start_t2);
-        for (int j = ngram_start_t2+1; j < ngram_start_t2 + 2; ++j) {
-            const uint64_t token = get_token(inp, draft, j);
-            ngram_t2 <<= 16;
-            ngram_t2 |= token;
-        }
-        llama_ngram_cache::iterator part_t2_it = nc_t2.find(ngram_t2);
-        llama_ngram_cache_part part_t2;
-        if (part_t2_it != nc_t2.end()) {
-            part_t2 = part_t2_it->second;
+    for (std::pair<llama_token, int> token_count_primary : part_primary) {
+        const llama_token token = token_count_primary.first;
+        const int32_t count_primary = token_count_primary.second;
+
+        if (count_primary > max_count_primary) {
+            max_token = token;
+            max_count_primary = count_primary;
         }
+        sum_count_primary += count_primary;
+    }
 
-        for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
-            if (ngram_size > inp_size) {
-                continue;
-            }
+    if (sum_count_primary < draft_min_sample_size_t1[2-1]) {
+        return -1;
+    }
+    if (100*max_count_primary < draft_min_percent_t1[2-1]*sum_count_primary) {
+        return -1;
+    }
+    return max_token;
+}
 
-            llama_ngram_cache & nc_t1 = ncs_t1[ngram_size - ngram_min];
+static llama_token try_draft(
+    llama_ngram_cache & nc_primary, const std::vector<uint64_t> & ngrams_primary, llama_ngram_cache_part & part_validate,
+    const int * min_sample_size, const int * min_percent) {
 
-            const int ngram_start_t1 = inp_size-ngram_size + draft.size()-1;
-            uint64_t ngram_t1 = get_token(inp, draft, ngram_start_t1);
-            for (int j = ngram_start_t1+1; j < ngram_start_t1 + ngram_size; ++j) {
-                const uint64_t token = get_token(inp, draft, j);
-                ngram_t1 <<= 16;
-                ngram_t1 |= token;
-            }
+    llama_token drafted_token = -1;
 
-            llama_ngram_cache::iterator part_t1_it = nc_t1.find(ngram_t1);
-            if (part_t1_it == nc_t1.end()) {
-                continue;
-            }
-            const llama_ngram_cache_part part_t1 = part_t1_it->second;
+    for (int i = ngrams_primary.size()-1; i >= 0 && drafted_token == -1; --i) {
+        const uint64_t ngram_primary = ngrams_primary[i];
 
-            int max_count_t1 = 0;
-            int max_count_t2 = 0;
-            int sum_count_t1 = 0;
-            llama_token max_token = -1;
+        llama_ngram_cache::iterator part_primary_it = nc_primary.find(ngram_primary);
+        if (part_primary_it == nc_primary.end()) {
+            continue;
+        }
+        const llama_ngram_cache_part part_primary = part_primary_it->second;
 
-            for (std::pair<llama_token, int> token_count_t1 : part_t1) {
-                const llama_token token = token_count_t1.first;
+        int max_count_primary = 0;
+        int max_count_validate = 0;
+        int sum_count_primary = 0;
+        llama_token max_token = -1;
 
-                llama_ngram_cache_part::iterator token_count_t2_it = part_t2.find(token);
-                const int32_t count_t1 = token_count_t1.second;
-                const int32_t count_t2 = token_count_t2_it != part_t2.end() ? 100*token_count_t2_it->second : 1;
+        for (std::pair<llama_token, int> token_count_primary : part_primary) {
+            const llama_token token = token_count_primary.first;
 
-                if (count_t1*count_t2 > max_count_t1*max_count_t2) {
-                    max_token = token;
-                    max_count_t1 = count_t1;
-                    max_count_t2 = count_t2;
-                }
-                sum_count_t1 += count_t1;
-            }
-            // Skip this candidate if the sample size is too low:
-            if (sum_count_t1 < draft_min_sample_size[ngram_size-1]) {
-                continue;
-            }
-            // skip this candidate if the empirically most likely token following this token is not likely enough:
-            if (100*max_count_t1 < draft_min_percent[ngram_size-1]*sum_count_t1) {
-                continue;
+            llama_ngram_cache_part::iterator token_count_validate_it = part_validate.find(token);
+
+            const int32_t count_primary = token_count_primary.second;
+            const int32_t count_validate = token_count_validate_it != part_validate.end() ? 100*token_count_validate_it->second : 1;
+
+            if (count_primary*count_validate > max_count_primary*max_count_validate) {
+                max_token = token;
+                max_count_primary = count_primary;
+                max_count_validate = count_validate;
            }
+            sum_count_primary += count_primary;
        }
 
-            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t1);
-            draft.push_back(max_token);
-            draft_success = true;
-            break;
+        if (sum_count_primary < min_sample_size[i]) {
+            continue;
+        }
+        if (100*max_count_primary < min_percent[i]*sum_count_primary) {
+            continue;;
+        }
+        drafted_token = max_token;
    }
 
-        if (!draft_success) {
-            int max_count_t2 = 0;
-            int sum_count_t2 = 0;
-            llama_token max_token = -1;
+    return drafted_token;
+}
 
-            for (std::pair<llama_token, int> token_count_t2 : part_t2) {
-                const llama_token token = token_count_t2.first;
-                const int32_t count_t2 = token_count_t2.second;
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    llama_ngram_cache & nc_t1, llama_ngram_cache & nc_t2, llama_ngram_cache & nc_t3
+) {
+    const int inp_size = inp.size();
 
-                if (count_t2 > max_count_t2) {
-                    max_token = token;
-                    max_count_t2 = count_t2;
-                }
-                sum_count_t2 += count_t2;
-            }
+    if (inp_size < 2) {
+        return;
+    }
 
-            // Skip this candidate if the sample size is too low:
-            if (sum_count_t2 < draft_min_sample_size[2-1]) {
-                break;
-            }
-            // skip this candidate if the empirically most likely token following this token is not likely enough:
-            if (100*max_count_t2 < draft_min_percent[2-1]*sum_count_t2) {
-                break;
-            }
+    while ((int) draft.size()-1 < n_draft) {
+        llama_token drafted_token = -1;
 
-            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t2);
-            draft.push_back(max_token);
-            draft_success = true;
-            break;
+        const int ngram_start_t23 = inp_size-2 + draft.size()-1;
+        uint64_t ngram_t23 = get_token(inp, draft, ngram_start_t23);
+        for (int j = ngram_start_t23+1; j < ngram_start_t23 + 2; ++j) {
+            const uint64_t token = get_token(inp, draft, j);
+            ngram_t23 <<= 16;
+            ngram_t23 |= token;
+        }
+        llama_ngram_cache::iterator part_t3_it = nc_t3.find(ngram_t23);
+        llama_ngram_cache_part part_t3;
+        if (part_t3_it != nc_t3.end()) {
+            part_t3 = part_t3_it->second;
        }
 
-        if (!draft_success) {
+        std::vector<uint64_t> ngrams_t12;
+        for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+            const int ngram_start_t12 = inp_size-ngram_size + draft.size()-1;
+            uint64_t ngram_t12 = get_token(inp, draft, ngram_start_t12);
+            for (int j = ngram_start_t12+1; j < ngram_start_t12 + ngram_size; ++j) {
+                const uint64_t token = get_token(inp, draft, j);
+                ngram_t12 <<= 16;
+                ngram_t12 |= token;
+            }
+            ngrams_t12.push_back(ngram_t12);
+        }
+        if (drafted_token == -1) {
+            drafted_token = try_draft(nc_t1, ngrams_t12, part_t3, draft_min_sample_size_t1, draft_min_percent_t1);
+        }
+        if (drafted_token == -1) {
+            drafted_token = try_draft(nc_t2, ngrams_t12, part_t3, draft_min_sample_size_t2, draft_min_percent_t2);
+        }
+        if (drafted_token == -1) {
+            drafted_token = try_draft(nc_t3, ngram_t23);
+        }
+
+        if (drafted_token == -1) {
            break;
        }
+
+        LOG(" - draft candidate: token=%d\n", drafted_token);
+        draft.push_back(drafted_token);
    }
 };
 
-void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename) {
+void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename) {
-    GGML_ASSERT(ngram_cache.size() == 1);
     std::ofstream file_out(filename, std::ios::binary);
-    for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache[0]) {
+    for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache) {
        const uint64_t ngram = item.first;
        llama_ngram_cache_part token_counts = item.second;
        GGML_ASSERT(!token_counts.empty());
@@ -2054,8 +2083,7 @@ void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::s
 llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
     std::ifstream hashmap_file(filename, std::ios::binary);
     if (!hashmap_file) {
-        fprintf(stderr, "error: failed to open file '%s'\n", filename.c_str());
-        exit(1);
+        throw std::system_error();
     }
     llama_ngram_cache ngram_cache;
 
@@ -2084,3 +2112,29 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {

     return ngram_cache;
 }
+
+void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
+    for (std::pair<uint64_t, llama_ngram_cache_part> ngram_part : ngram_cache_add) {
+        const uint64_t ngram = ngram_part.first;
+        llama_ngram_cache_part part = ngram_part.second;
+
+        llama_ngram_cache::iterator part_merged_it = ngram_cache_target.find(ngram);
+        if (part_merged_it == ngram_cache_target.end()) {
+            ngram_cache_target.emplace(ngram, part);
+            continue;
+        }
+
+        for (std::pair<llama_token, int32_t> token_count : part) {
+            const llama_token token = token_count.first;
+            const int32_t count = token_count.second;
+
+            llama_ngram_cache_part::iterator token_count_merged_it = part_merged_it->second.find(token);
+            if (token_count_merged_it == part_merged_it->second.end()) {
+                part_merged_it->second.emplace(token, count);
+                continue;
+            }
+
+            token_count_merged_it->second += count;
+        }
+    }
+}
[Diffs for the remaining 5 changed files were not loaded in this capture.]
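For orientation, here is a minimal sketch of how the refactored n-gram cache API shown in this diff fits together in a lookup-decoding loop. This is not code from the commit: the function signatures match the diff above, but the file names, the n-gram size range (1..4), the draft length, and the cache roles passed to llama_ngram_cache_draft (context cache as t1, dynamic cache as t2, static cache as t3) are assumptions about how the lookup example is meant to wire them together.

// Illustrative sketch only; assumes the declarations from common/common.h
// as changed in this commit (llama_token comes from llama.h via common.h).
#include <string>
#include <vector>

#include "common.h"

void lookup_decoding_sketch(std::vector<llama_token> & inp) {
    // Hypothetical cache paths; -lcs and -lcd would supply these.
    std::string static_path  = "lookup-static.bin";
    std::string dynamic_path = "lookup-dynamic.bin";

    // Per the diff, llama_ngram_cache_load throws std::system_error if the
    // file cannot be opened, so these files are assumed to exist.
    llama_ngram_cache nc_static  = llama_ngram_cache_load(static_path);
    llama_ngram_cache nc_dynamic = llama_ngram_cache_load(dynamic_path);
    llama_ngram_cache nc_context;

    // Index the current context; all tokens are new, no progress printing.
    llama_ngram_cache_update(nc_context, 1, 4, inp, (int) inp.size(), false);

    // Draft up to 8 tokens; the draft is seeded with the most recently
    // sampled token, matching the (int) draft.size()-1 < n_draft loop.
    std::vector<llama_token> draft = {inp.back()};
    llama_ngram_cache_draft(inp, draft, 8, 1, 4, nc_context, nc_dynamic, nc_static);

    // ... validate `draft` against the model and append accepted tokens ...

    // Fold the context statistics into the dynamic cache and persist it,
    // which is what makes the -lcd cache "updated by generation".
    llama_ngram_cache_merge(nc_dynamic, nc_context);
    llama_ngram_cache_save(nc_dynamic, dynamic_path);
}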