passkey : add "self-extend"-like context extension #4810

Merged
merged 2 commits on Jan 8, 2024
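This PR adds a simple "self-extend"-like scheme, in the spirit of LongLM SelfExtend, to the passkey example: while the prompt is being ingested, the positions of already-processed tokens are periodically divided by a grouping factor n_grp, so a model trained with a context of n_ctx_train tokens can be fed roughly n_ctx_train*n_grp tokens while its effective RoPE positions stay inside the training window. The example gains a new N_GRP command-line argument (usage: passkey MODEL_PATH N_JUNK N_GRP I_POS SEED), and llama.cpp gains a new public helper, llama_kv_cache_seq_div, that performs the position division on a range of the KV cache. As an illustrative invocation (values chosen only for this description), something like "passkey ./model.gguf 250 4 90 1234" would repeat the junk text 250 times, compress positions by a factor of 4, place the passkey at position 90, and use seed 1234.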
35 changes: 27 additions & 8 deletions examples/passkey/passkey.cpp
@@ -10,14 +10,15 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n" , argv[0]);
+        printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]);
         return 1 ;
     }
 
     int seed = -1;
 
     int n_junk = 250; // number of times to repeat the junk text
     int n_keep = 32;  // number of tokens in the prompt prefix
+    int n_grp  = 1;   // if more than 1 - perform LongLM SelfExtend
     int i_pos  = -1;  // position of the passkey in the junk text
 
     if (argc >= 2) {
@@ -29,11 +30,15 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 4) {
-        i_pos = std::stoi(argv[3]);
+        n_grp = std::stoi(argv[3]);
     }
 
     if (argc >= 5) {
-        seed = std::stoi(argv[4]);
+        i_pos = std::stoi(argv[4]);
     }
 
+    if (argc >= 6) {
+        seed = std::stoi(argv[5]);
+    }
+
     if (seed == -1) {
@@ -86,11 +91,13 @@ int main(int argc, char ** argv) {
     llama_context_params ctx_params = llama_context_default_params();
 
     ctx_params.seed = seed;
-    ctx_params.n_ctx = llama_n_ctx_train(model) + n_keep;
+    ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep;
     ctx_params.n_batch = 512;
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
+    GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
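As a rough worked example (numbers not taken from the PR): for a model with a 2048-token training context and n_grp = 4, this gives ctx_params.n_ctx = 2048*4 + 32 = 8224, and the default n_batch = 512 passes the new assert since 512 % 4 == 0.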
@@ -113,11 +120,12 @@ int main(int argc, char ** argv) {
     // total length of the sequences including the prompt
     const int n_len = n_tokens_all + n_predict;
 
-    const int n_ctx    = llama_n_ctx(ctx) - n_keep;
-    const int n_kv_req = llama_n_ctx(ctx);
-    const int n_batch  = ctx_params.n_batch;
+    const int n_ctx       = llama_n_ctx(ctx) - n_keep;
+    const int n_kv_req    = llama_n_ctx(ctx);
+    const int n_batch     = ctx_params.n_batch;
+    const int n_batch_grp = ctx_params.n_batch/n_grp;
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
 
     // print the prompt token-by-token
 
@@ -132,6 +140,17 @@ int main(int argc, char ** argv) {
 
     // fill the KV cache
     for (int i = 0; i < n_ctx; i += n_batch) {
+        if (i > 0 && n_grp > 1) {
+            // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
+            const int ib = i/n_batch - 1;
+            const int bd = n_batch_grp*(n_grp - 1);
+
+            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_cache_seq_div  (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+
+            n_past -= bd;
+        }
+
         llama_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
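To make the position bookkeeping of this loop concrete, here is a small standalone sketch (not part of the PR) that replays only the arithmetic for an assumed prompt of 2048 tokens with n_batch = 512 and n_grp = 4; the llama_kv_cache_seq_shift/seq_div calls are summarized as comments.

#include <cstdio>

int main() {
    const int n_batch     = 512;
    const int n_grp       = 4;
    const int n_batch_grp = n_batch/n_grp;  // 128
    const int n_tokens    = 2048;           // assumed prompt length (illustration only)

    int n_past = 0;

    for (int i = 0; i < n_tokens; i += n_batch) {
        if (i > 0 && n_grp > 1) {
            const int ib = i/n_batch - 1;
            const int bd = n_batch_grp*(n_grp - 1);  // 384

            // seq_shift moves the previous batch from [n_past - n_batch, n_past)
            // to [n_past - n_batch + ib*bd, n_past + ib*bd); seq_div then divides
            // that range by n_grp, so it shrinks to n_batch_grp positions placed
            // right after the batches that were already compressed
            n_past -= bd;
        }

        printf("batch %d: tokens %4d..%4d decoded at positions %4d..%4d\n",
                i/n_batch, i, i + n_batch - 1, n_past, n_past + n_batch - 1);

        n_past += n_batch;
    }

    // prints 0..511, 128..639, 256..767, 384..895: each batch starts only
    // n_batch_grp = 128 positions after the previous one, instead of 512
    printf("final n_past = %d\n", n_past);  // 896

    return 0;
}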
34 changes: 34 additions & 0 deletions llama.cpp
@@ -1903,6 +1903,28 @@ static void llama_kv_cache_seq_shift(
     cache.head = new_head != cache.size ? new_head : 0;
 }
 
+static void llama_kv_cache_seq_div(
+        struct llama_kv_cache & cache,
+        llama_seq_id seq_id,
+        llama_pos p0,
+        llama_pos p1,
+        int d) {
+    if (p0 < 0) p0 = 0;
+    if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
+
+    for (uint32_t i = 0; i < cache.size; ++i) {
+        if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+            cache.has_shift = true;
+
+            {
+                llama_pos p_old = cache.cells[i].pos;
+                cache.cells[i].pos   /= d;
+                cache.cells[i].delta += cache.cells[i].pos - p_old;
+            }
+        }
+    }
+}
+
 //
 // model loading and saving
 //
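For a quick sanity check on the arithmetic (example values, not from the PR): a cell at pos = 1000 with d = 4 ends up at pos = 250 and accumulates a delta of -750. Setting has_shift together with these per-cell deltas lets the existing K-shift path re-apply RoPE at the new positions, the same bookkeeping that llama_kv_cache_seq_shift already relies on.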
@@ -10140,9 +10162,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
 }
 
 void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
+    if (delta == 0) {
+        return;
+    }
+
     llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta);
 }
 
+void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    if (d == 1) {
+        return;
+    }
+
+    llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
+}
+
 // Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
7 changes: 7 additions & 0 deletions llama.h
@@ -484,6 +484,13 @@ extern "C" {
             llama_pos p1,
             llama_pos delta);
 
+    LLAMA_API void llama_kv_cache_seq_div(
+            struct llama_context * ctx,
+            llama_seq_id seq_id,
+            llama_pos p0,
+            llama_pos p1,
+            int d);
+
     //
     // State / sessions
     //
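The new entry point is generic, so callers other than the passkey example can use it as well. A minimal caller-side sketch (hypothetical usage, not from the PR) that compresses the whole history of sequence 0 by a factor of 4 and keeps its own position counter consistent:

// assumes an existing llama_context * ctx and a running position counter n_past
const int n_grp = 4;  // compression factor, chosen for illustration

llama_kv_cache_seq_div(ctx, 0, 0, n_past, n_grp);    // positions p in [0, n_past) become p/n_grp
n_past = (n_past + n_grp - 1)/n_grp;                  // next token goes right after the compressed range

Subsequent decoding then continues with the compressed positions; re-applying RoPE to keys already in the cache goes through the same delta/has_shift mechanism as llama_kv_cache_seq_shift.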