llama : refactor k-shift implementation + initial defragmentation #5691

Merged
merged 25 commits into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1a99981  llama : refactor k-shift implementation (ggerganov, Feb 23, 2024)
dd39219  llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add (ggerganov, Feb 24, 2024)
89b2a43  llama : cont k-shift refactoring + normalize type names (ggerganov, Feb 24, 2024)
2b9a9bf  minor : fix MPI builds (ggerganov, Feb 24, 2024)
5f5b1b5  llama : reuse n_rot from the build context (ggerganov, Feb 24, 2024)
42ddf48  llama : revert enum name changes from this PR (ggerganov, Feb 24, 2024)
31e1ec9  llama : update llama_rope_type (ggerganov, Feb 24, 2024)
decea31  llama : add comment about rope values (ggerganov, Feb 24, 2024)
8f9fe6d  llama : fix build (ggerganov, Feb 24, 2024)
79e2761  passkey : apply kv cache updates explicitly (ggerganov, Feb 24, 2024)
18da970  llama : change name to llama_kv_cache_update() (ggerganov, Feb 24, 2024)
b75ec64  llama : add llama_kv_cache_seq_pos_max() (ggerganov, Feb 24, 2024)
032ff85  passkey : fix llama_kv_cache_seq_pos_max() usage (ggerganov, Feb 25, 2024)
715a343  llama : some llama_kv_cell simplifications (ggerganov, Feb 25, 2024)
fdfa5bc  llama : add llama_kv_cache_compress (EXPERIMENTAL) (ggerganov, Feb 25, 2024)
0d6f873  Merge branch 'master' into gg/refactor-k-shift (ggerganov, Feb 25, 2024)
9ec749d  llama : add alternative KV cache merging (EXPERIMENTAL) (ggerganov, Feb 25, 2024)
65f21ec  llama : add llama_kv_cache_defrag (ggerganov, Feb 25, 2024)
d141c74  Merge branch 'master' into gg/refactor-k-shift (ggerganov, Feb 25, 2024)
1b6aeb8  llama : comments (ggerganov, Feb 25, 2024)
2d7203b  llama : remove llama_kv_cache_compress (ggerganov, Feb 25, 2024)
65323bc  llama : defragment via non-overlapping moves (ggerganov, Feb 25, 2024)
4eaaace  llama : ggml_graph based defrag implementation (ggerganov, Feb 25, 2024)
0b72ded  llama : switch the loop order in build_defrag (ggerganov, Feb 25, 2024)
5a122c2  llama : add comments (ggerganov, Feb 25, 2024)
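
The net API change across these commits: llama_kv_cache_seq_shift becomes llama_kv_cache_seq_add, pending KV-cache operations are now applied explicitly via llama_kv_cache_update(), and llama_kv_cache_seq_pos_max() and llama_kv_cache_defrag() are introduced. The declarations below are reconstructed from the call sites in the diffs that follow; consult llama.h at this commit for the authoritative signatures.

// adds delta to the positions of all cells of seq_id in [p0, p1)
// (formerly llama_kv_cache_seq_shift)
void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);

// returns the largest position present in the cache for seq_id
llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id);

// schedules a defragmentation of the KV cache, applied on the next update
void llama_kv_cache_defrag(struct llama_context * ctx);

// applies pending KV-cache operations (K-shifts, defragmentation)
void llama_kv_cache_update(struct llama_context * ctx);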
examples/infill/infill.cpp (2 additions & 2 deletions)

@@ -447,8 +447,8 @@ int main(int argc, char ** argv) {
             LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                 n_past, n_left, n_ctx, params.n_keep, n_discard);

-            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);

             n_past -= n_discard;
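
The hunk above is the standard context-shift recipe under the new name: remove a span after the kept prefix, then slide everything behind it back so positions stay contiguous. A minimal self-contained sketch, assuming a populated llama_context * ctx with all tokens in sequence 0 (the concrete values are hypothetical):

// context shift: discard half of the non-kept tokens, then shift the rest
const int n_past    = 512;                   // tokens currently in seq 0 (hypothetical)
const int n_keep    = 32;                    // prefix that must survive (hypothetical)
const int n_discard = (n_past - n_keep) / 2; // same heuristic as the examples

llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);             // drop [n_keep, n_keep + n_discard)
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past, -n_discard); // shift the tail left by n_discard

// the caller then reduces its own counter to match: n_past -= n_discard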
examples/main/main.cpp (5 additions & 5 deletions)

@@ -548,8 +548,8 @@ int main(int argc, char ** argv) {
             LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                 n_past, n_left, n_ctx, params.n_keep, n_discard);

-            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);

             n_past -= n_discard;

@@ -576,9 +576,9 @@
             LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
             LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);

-            llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd);
-            llama_kv_cache_seq_div  (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
-            llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
+            llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
+            llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
+            llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);

             n_past -= bd;
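
The second hunk is the self-extend (group attention) transform: positions inside the window of width ga_w are divided by the group factor ga_n, and the spans before and after the window are translated so the sequence stays contiguous. A worked sketch of the first application, with hypothetical parameters ga_n = 4, ga_w = 512, ga_i = 0 (so the ib*bd translation terms are zero and dd = 512/4 - 512 = -384):

// first self-extend step; all numbers illustrative, not part of the API
const int n_past = 1024;                           // hypothetical
llama_kv_cache_seq_add(ctx, 0, 0,   n_past, 0);    // head translation: a no-op on the first step
llama_kv_cache_seq_div(ctx, 0, 0,   512,    4);    // window positions [0, 512) become [0, 128)
llama_kv_cache_seq_add(ctx, 0, 512, n_past, -384); // tail [512, n_past) slides back by 384

// the cache now spans 384 fewer positions, which is why the example
// shrinks its counter afterwards (n_past -= bd)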
examples/passkey/passkey.cpp (15 additions & 10 deletions)

@@ -126,7 +126,7 @@ int main(int argc, char ** argv) {
     const int n_batch     = ctx_params.n_batch;
     const int n_batch_grp = ctx_params.n_batch/n_grp;

-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);

     // print the prompt token-by-token

@@ -146,10 +146,11 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);

-            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_cache_seq_div  (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_update  (ctx);

-            n_past -= bd;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }

         llama_batch_clear(batch);

@@ -179,10 +180,12 @@

         LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);

-        llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
-        llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+        llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+        llama_kv_cache_defrag (ctx);
+        llama_kv_cache_update (ctx);

-        n_past -= n_discard;
+        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

         llama_batch_clear(batch);

@@ -208,10 +211,12 @@
         if (n_discard > 0) {
             LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);

-            llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
+            llama_kv_cache_defrag (ctx);
+            llama_kv_cache_update (ctx);

-            n_past -= n_discard;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
     }
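
The passkey changes separate scheduling from execution: seq_rm/seq_add/defrag only record pending operations, llama_kv_cache_update() applies them, and llama_kv_cache_seq_pos_max() replaces the manual n_past bookkeeping. The essential pattern, lifted from the hunks above (n_keep, n_discard, and n_ctx as defined in the example):

// schedule: drop a span, shift the tail down, and request a defrag
llama_kv_cache_seq_rm (ctx, 0, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
llama_kv_cache_defrag (ctx);

// apply all pending K-shifts plus the defragmentation in one pass
llama_kv_cache_update (ctx);

// derive n_past from the cache state instead of tracking it by hand
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;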
examples/server/server.cpp (4 additions & 4 deletions)

@@ -1622,8 +1622,8 @@ struct llama_server_context
                     {"n_system_tokens", system_tokens.size()},
                     {"n_cache_tokens", slot.cache_tokens.size()}
                 });
-                llama_kv_cache_seq_rm   (ctx, slot.id, n_keep            , n_keep + n_discard);
-                llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);

                 for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
                 {

@@ -1919,9 +1919,9 @@ struct llama_server_context
                 LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                 LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);

-                llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
                 llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+                llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);

                 slot.n_past_se -= bd;
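
In the server the same calls are scoped per slot: each slot owns one sequence, so slot.id takes the place of sequence 0, and the shift's upper bound accounts for the shared system prompt that precedes the slot's own tokens. A minimal sketch, with the slot id and all sizes hypothetical:

// per-slot context shift; every value below is illustrative
const llama_seq_id seq       = 1;    // hypothetical slot id
const int          n_system  = 64;   // hypothetical system prompt length
const int          n_keep    = 32;   // hypothetical
const int          n_past    = 512;  // hypothetical
const int          n_discard = (n_past - n_keep) / 2;

llama_kv_cache_seq_rm (ctx, seq, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(ctx, seq, n_keep + n_discard, n_system + n_past, -n_discard);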