llama : greatly reduce output buffer memory usage #6122

Merged
merged 26 commits into from Mar 26, 2024
Changes from 1 commit
26 commits
1fd1918
llama : greatly reduce logits memory usage
compilade Mar 15, 2024
98914c0
llama : more compact state saving and reloading
compilade Mar 15, 2024
705d393
llama : fix lctx.n_outputs not being set before building graph
compilade Mar 16, 2024
25981fc
perplexity : adapt to the logits API changes
compilade Mar 17, 2024
17b45c9
perplexity : fix Winogrande, use correct logits for second choice start
compilade Mar 17, 2024
d0129e8
perplexity : normalize spaces and punctuation in Winogrande sentences
compilade Mar 17, 2024
487f89e
llama : fix embedding conditions
compilade Mar 17, 2024
408fcb0
llama : fix llama_get_embeddings_ith when the resulting id is 0
compilade Mar 17, 2024
e19cb3a
llama : fix wrong n_outputs in llama_set_inputs
compilade Mar 17, 2024
a57fa7f
llama : fix not-skipping outputs of non-causal models
compilade Mar 18, 2024
711b0bc
llama : fix running a batch with n_outputs == 0
compilade Mar 18, 2024
d100502
llama : keep same graph topology even when n_outputs == 0
compilade Mar 18, 2024
99c37cc
ggml : saner ggml_can_repeat with empty tensors
compilade Mar 18, 2024
6bf7f3f
ggml : do not multi-thread ops returning empty tensors
compilade Mar 18, 2024
09bb15a
ggml : make ggml_is_empty public and work with views
compilade Mar 19, 2024
4551e7e
llama : use a vector for ctx->output_ids
compilade Mar 19, 2024
8b826c5
ggml : skip empty tensors in all backends
compilade Mar 19, 2024
d04cfaf
llama : fix llama_output_reserve nullptr deref when new_size is 0
compilade Mar 19, 2024
8f70dcb
perplexity : make Winogrande work as it does on master
compilade Mar 19, 2024
615a3a4
llama : clearer error messages for invalid logits or embeddings ids
compilade Mar 19, 2024
7d8d6b5
llama : handle errors from llama_output_reserve at call sites
compilade Mar 21, 2024
5f33a67
perplexity : make hellaswag and multiple-choice outputs identical to …
compilade Mar 21, 2024
ffa9abd
Merge branch 'master' into compilade/smaller-output-buffer
compilade Mar 25, 2024
e9095ac
llama : allow loading state saved with a different ctx size
compilade Mar 26, 2024
5027d81
llama : minor
ggerganov Mar 26, 2024
20248e8
readme : update recent API changes, and warn about Vulkan
compilade Mar 26, 2024
llama : fix lctx.n_outputs not being set before building graph
compilade committed Mar 17, 2024
commit 705d3937eaa1f1f370fda188564405e358905d8c
llama.cpp: 149 changes (80 additions & 69 deletions)
@@ -2103,7 +2103,7 @@ struct llama_context {

int32_t * output_ids = nullptr; // map token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
int32_t n_outputs = 0; // number of actually-used outputs in the current or previous batch
int32_t n_outputs = 0; // number of actually-used outputs in the current or previous ubatch

bool logits_all = false;
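The fields shown here (output_ids, output_size, n_outputs) carry the core idea of this PR: output_ids maps a token's position in the batch to a row in compact logits/embd buffers that are sized by the number of requested outputs. A minimal standalone sketch of the mapping idea (hypothetical helper names, not the actual llama.cpp code):

```cpp
#include <cstdint>
#include <vector>

// Sketch: build the position -> output-row map from per-token output flags.
// Positions without a requested output map to -1.
struct output_map {
    std::vector<int32_t> output_ids; // token position -> compact row, or -1
    int32_t n_outputs = 0;           // rows actually used
};

static output_map build_output_map(const std::vector<int8_t> & logits_flags) {
    output_map m;
    m.output_ids.assign(logits_flags.size(), -1);
    for (size_t i = 0; i < logits_flags.size(); ++i) {
        if (logits_flags[i]) {
            m.output_ids[i] = m.n_outputs++;
        }
    }
    return m;
}
// The logits buffer then needs only n_outputs * n_vocab floats, and row
// output_ids[i] holds the logits for position i.
```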

@@ -8985,24 +8985,25 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
int32_t * data = (int32_t *) lctx.inp_out_ids->data;

int32_t n_outputs = 0;
if (batch.logits) {
int32_t n_outputs = 0;
for (int i = 0; i < n_tokens; ++i) {
if (batch.logits[i]) {
data[n_outputs++] = i;
}
}
lctx.n_outputs = n_outputs;
} else if (lctx.logits_all || (cparams.embeddings && hparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
for (int i = 0; i < n_tokens; ++i) {
data[i] = i;
}
lctx.n_outputs = n_tokens;
n_outputs = n_tokens;
} else {
// only keep last output
data[0] = n_tokens - 1;
lctx.n_outputs = 1;
n_outputs = 1;
}
// the graph needs to have been passed the correct number of outputs
GGML_ASSERT(lctx.n_outputs == n_outputs);
}

GGML_ASSERT(
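For context, this is what the per-position output flags look like from the caller's side (a sketch against the public llama.h API at the time of this PR; model setup, error handling, and sampling omitted):

```cpp
// Request logits only for the last position of the prompt.
// Assumes `ctx` is an initialized llama_context and `tokens` holds the prompt tokens.
llama_batch batch = llama_batch_init(/*n_tokens=*/ (int32_t) tokens.size(), /*embd=*/ 0, /*n_seq_max=*/ 1);
batch.n_tokens = (int32_t) tokens.size();
for (int32_t i = 0; i < batch.n_tokens; ++i) {
    batch.token[i]     = tokens[i];
    batch.pos[i]       = i;
    batch.n_seq_id[i]  = 1;
    batch.seq_id[i][0] = 0;
    batch.logits[i]    = (i == batch.n_tokens - 1); // only one output row is requested
}
if (llama_decode(ctx, batch) == 0) {
    const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
    // ... pick the next token from `logits` ...
}
llama_batch_free(batch);
```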
@@ -9202,6 +9203,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
const auto n_embd = hparams.n_embd;
const int64_t capacity = lctx.output_size;

// TODO: use a per-batch flag for logits presence instead
const bool has_logits = cparams.causal_attn;
const bool has_embd = cparams.embeddings && (!hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);

@@ -9221,10 +9223,11 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {
const size_t buf_output_size = (lctx.logits_size + lctx.embd_size)*sizeof(float);

if (lctx.buf_output) {
#ifndef NDEBUG
// This doesn't happen often
// #ifndef NDEBUG
const size_t prev_size = ggml_backend_buffer_get_size(lctx.buf_output);
fprintf(stderr, "%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size/ 1024.0 / 1024.0);
#endif
LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, buf_output_size / 1024.0 / 1024.0);
// #endif
ggml_backend_buffer_free(lctx.buf_output);
lctx.buf_output = nullptr;
lctx.logits = nullptr;
@@ -9246,7 +9249,7 @@ static void llama_output_reserve(llama_context & lctx, int32_t n_outputs) {

ggml_backend_buffer_clear(lctx.buf_output, 0);

lctx.n_outputs = n_outputs; // also set in llama_set_inputs() before a batch
lctx.n_outputs = 0;
}
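To put rough numbers on the PR title (illustrative values, not taken from the diff: n_vocab = 32000, a 4096-token context, float32 logits), sizing the logits buffer for every context position versus only for the requested outputs differs by several orders of magnitude when only the last token's logits are needed:

```cpp
// Illustrative size comparison; exact shapes depend on the model and context settings.
const int64_t n_ctx = 4096, n_vocab = 32000, n_outputs = 1;
const size_t per_ctx_size    = (size_t)(n_ctx     * n_vocab) * sizeof(float); // 524,288,000 B, about 500 MiB
const size_t per_output_size = (size_t)(n_outputs * n_vocab) * sizeof(float); //     128,000 B, about 125 KiB
```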


@@ -9325,8 +9328,8 @@ static int llama_decode_internal(
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;

int32_t n_logits = 0;
int32_t n_logits_prev = 0;
int32_t n_outputs = 0;
int32_t n_outputs_prev = 0;

const auto n_ubatch = cparams.n_ubatch;

@@ -9338,27 +9341,25 @@
// reserve output buffer
if (batch_all.logits) {
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch_all.logits[i]) {
n_logits++;
}
n_outputs += batch_all.logits[i] != 0;
}
llama_output_reserve(lctx, n_logits);
llama_output_reserve(lctx, n_outputs);
int32_t i_logits = 0;
for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch_all.logits[i]) {
lctx.output_ids[i] = i_logits++;
}
}
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
n_logits = n_tokens_all;
llama_output_reserve(lctx, n_logits);
n_outputs = n_tokens_all;
llama_output_reserve(lctx, n_outputs);
for (uint32_t i = 0; i < n_tokens_all; ++i) {
lctx.output_ids[i] = i;
}
} else {
// keep last logits only
n_logits = 1;
llama_output_reserve(lctx, n_logits);
// keep last output only
n_outputs = 1;
llama_output_reserve(lctx, n_outputs);
lctx.output_ids[0] = 0;
}

@@ -9377,6 +9378,27 @@ static int llama_decode_internal(
/* .all_seq_id = */ batch_all.all_seq_id,
};

// count the outputs in this u_batch
{
int32_t n_outputs_new = 0;

if (u_batch.logits) {
for (uint32_t i = 0; i < n_tokens; i++) {
n_outputs_new += u_batch.logits[i] != 0;
}
} else if (lctx.logits_all) {
n_outputs_new = n_tokens;
} else {
// keep last output only
if (cur_token + n_tokens >= n_tokens_all) {
n_outputs_new = 1;
}
}

// needs to happen before the graph is built
lctx.n_outputs = n_outputs_new;
}

int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
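A worked example of the "keep last output only" branch above (illustrative numbers, not from the PR): with n_tokens_all = 10 and n_ubatch = 4, the batch splits into ubatches of 4, 4, and 2 tokens; only the final ubatch satisfies cur_token + n_tokens >= n_tokens_all, so lctx.n_outputs is 0, 0, 1 across the three graph builds.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_tokens_all = 10, n_ubatch = 4;
    for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
        const uint32_t n_tokens      = std::min(n_ubatch, n_tokens_all - cur_token);
        const int32_t  n_outputs_new = (cur_token + n_tokens >= n_tokens_all) ? 1 : 0;
        // prints: [0..3] -> 0, [4..7] -> 0, [8..9] -> 1
        printf("ubatch [%u..%u] -> n_outputs_new = %d\n", cur_token, cur_token + n_tokens - 1, n_outputs_new);
    }
    return 0;
}
```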

@@ -9451,18 +9473,26 @@ static int llama_decode_internal(
embd = gf->nodes[gf->n_nodes - 1];

GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
// TODO: graph view to ignore the logits when not needed
} else {
if (strcmp(res->name, "result_output") == 0) {
// the token embeddings could be the second to last tensor, or any of the previous tensors
// NOTE: see build_result_output() for an idea of up to how many tensors to skip
for (int i = 3; strcmp(embd->name, "result_norm") != 0 && i <= 10; ++i) {
embd = gf->nodes[gf->n_nodes - i];
}
GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
} else {
GGML_ASSERT(false && "missing result_output tensor");
} else if (cparams.embeddings) {
// the embeddings could be in the second to last tensor, or any of the previous tensors
int i_embd = gf->n_nodes - 2;
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
i_embd = gf->n_nodes - i;
if (i_embd < 0) { break; }
embd = gf->nodes[i_embd];
}
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");

// TODO: use a per-batch flag to know when to skip logits while keeping embeddings
if (!cparams.causal_attn) {
res = nullptr; // do not extract logits when not needed
// skip computing logits
// TODO: is this safe?
gf->n_nodes = i_embd + 1;
}
} else {
embd = nullptr; // do not extract embeddings when not needed
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
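The loop above walks backwards from the end of the graph looking for the tensor named "result_norm" and, for non-causal embedding models, truncates the graph right after it so the logits matmul is never scheduled. A conceptual restatement of that search (simplified; the real code also bounds how far back it looks, and this assumes the ggml_cgraph layout used by llama.cpp at the time of this PR):

```cpp
#include <cstring>
#include "ggml.h"

// Find the last node named "result_norm" and optionally drop every node after it.
// Returns the embeddings tensor, or nullptr if the graph has no such node.
static struct ggml_tensor * find_embd_and_trim(struct ggml_cgraph * gf, bool skip_logits) {
    int i_embd = -1;
    for (int i = gf->n_nodes - 1; i >= 0; --i) {
        if (strcmp(gf->nodes[i]->name, "result_norm") == 0) {
            i_embd = i;
            break;
        }
    }
    if (i_embd < 0) {
        return nullptr;
    }
    if (skip_logits) {
        gf->n_nodes = i_embd + 1; // schedule only up to the embeddings
    }
    return gf->nodes[i_embd];
}
```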

@@ -9505,56 +9535,36 @@ static int llama_decode_internal(
//}

// extract logits
// TODO: do not compute and extract logits if only embeddings are needed
// update the graphs to skip "result_output" if logits are not needed
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
int32_t new_logits = 0;

if (u_batch.logits) {
for (uint32_t i = 0; i < n_tokens; i++) {
if (u_batch.logits[i]) {
new_logits++;
}
}
} else if (lctx.logits_all) {
new_logits += n_tokens;
} else {
// keep last logits only
if (cur_token + n_tokens >= n_tokens_all) {
new_logits += 1;
}
}
float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
const int32_t n_outputs_new = lctx.n_outputs;

if (new_logits) {
GGML_ASSERT(new_logits <= n_logits);
GGML_ASSERT((n_logits_prev+new_logits)*n_vocab <= (int64_t) lctx.logits_size);
ggml_backend_tensor_get_async(backend_res, res, lctx.logits, n_logits_prev*n_vocab*sizeof(float), new_logits*n_vocab*sizeof(float));
n_logits_prev += new_logits;
if (n_outputs_new) {
GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
}
}

// extract embeddings
if (cparams.embeddings && embd) {
if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);

switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
auto & embd_out = lctx.embd;

if (u_batch.logits) {
//embd_out.resize(n_embd * n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
if (u_batch.logits[i] == 0) {
continue;
}
// FIXME
ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
}
float * embd_out = lctx.embd + n_outputs_prev*n_embd;
const int32_t n_outputs_new = lctx.n_outputs;

if (n_outputs_new) {
GGML_ASSERT(n_outputs_prev+n_outputs_new <= n_outputs);
GGML_ASSERT((n_outputs_prev+n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_CLS:
Expand All @@ -9581,6 +9591,7 @@ static int llama_decode_internal(
} break;
}
}
n_outputs_prev += lctx.n_outputs;
}

// wait for the computation to finish (automatically done when obtaining the model output)
@@ -14639,11 +14650,11 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
const int32_t j = ctx->output_ids[i];

llama_synchronize(ctx);

if (ctx->logits && 0 <= j && j < ctx->n_outputs) {
if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) {
return ctx->logits + j*ctx->model.hparams.n_vocab;
}
LLAMA_LOG_ERROR("%s: invalid logits id %i\n", __func__, i);
LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n",
__func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big", j, ctx->output_size);
#ifndef NDEBUG
GGML_ASSERT(false);
#endif
@@ -14661,7 +14672,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {

llama_synchronize(ctx);

if (ctx->embd && 0 < j && j < ctx->n_outputs) {
if (ctx->embd && 0 < j && (size_t) j < ctx->output_size) {
return ctx->embd + j*ctx->model.hparams.n_embd;
}
LLAMA_LOG_ERROR("%s: invalid embeddings id %i\n", __func__, i);