Skip to content

Commit

Permalink
whisper : separate self and cross attention memory
Browse files Browse the repository at this point in the history
Initial step needed for supporting parallel decoders
  • Loading branch information
ggerganov committed Jan 8, 2023
1 parent e3c6416 commit 20b64fc
Showing 1 changed file with 89 additions and 46 deletions.
135 changes: 89 additions & 46 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,19 @@ static const std::map<e_model, size_t> MEM_REQ_MODEL = {
};

static const std::map<e_model, size_t> MEM_REQ_MEMORY = {
{ MODEL_TINY, 12ull*MB },
{ MODEL_BASE, 24ull*MB },
{ MODEL_SMALL, 70ull*MB },
{ MODEL_MEDIUM, 184ull*MB },
{ MODEL_LARGE, 306ull*MB },
{ MODEL_TINY, 3ull*MB },
{ MODEL_BASE, 6ull*MB },
{ MODEL_SMALL, 16ull*MB },
{ MODEL_MEDIUM, 43ull*MB },
{ MODEL_LARGE, 71ull*MB },
};

static const std::map<e_model, size_t> MEM_REQ_MEMORY_CROSS = {
{ MODEL_TINY, 9ull*MB },
{ MODEL_BASE, 18ull*MB },
{ MODEL_SMALL, 53ull*MB },
{ MODEL_MEDIUM, 141ull*MB },
{ MODEL_LARGE, 235ull*MB },
};

static const std::map<e_model, size_t> MEM_REQ_ENCODE = {
Expand Down Expand Up @@ -391,22 +399,27 @@ struct whisper_model {
std::vector<whisper_layer_encoder> layers_encoder;
std::vector<whisper_layer_decoder> layers_decoder;

// key + value memory
// key + value memory for self attention
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;

// key + value memory for cross attention
struct ggml_tensor * memory_cross_k;
struct ggml_tensor * memory_cross_v;

// context
struct ggml_context * ctx;
struct ggml_context * ctx_mem;
struct ggml_context * ctx_mem_cross;

// tensors
int n_loaded;
std::map<std::string, struct ggml_tensor *> tensors;
};

struct whisper_decoder_data {
};

struct whisper_context {
int64_t t_load_us = 0;
int64_t t_mel_us = 0;
Expand All @@ -417,6 +430,7 @@ struct whisper_context {

std::vector<uint8_t> * buf_model; // the model buffer is read-only and can be shared between processors
std::vector<uint8_t> buf_memory;
std::vector<uint8_t> buf_memory_cross;
std::vector<uint8_t> buf_compute;
std::vector<uint8_t> buf_compute_layer;

Expand Down Expand Up @@ -533,6 +547,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
wctx.buf_model = new std::vector<uint8_t>();
wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type));
wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type));
wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type));
wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)));
wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)));
}
Expand Down Expand Up @@ -631,6 +646,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
const size_t mem_required =
wctx.buf_model->size() +
wctx.buf_memory.size() +
wctx.buf_memory_cross.size() +
wctx.buf_compute.size() +
wctx.buf_compute_layer.size();

Expand Down Expand Up @@ -964,41 +980,58 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
}
}

// create the ggml memory context
// create the ggml context for the key/value memory (self-attention)
{
struct ggml_init_params params;
params.mem_size = wctx.buf_memory.size();
params.mem_buffer = wctx.buf_memory.data();

model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
auto & ctx = model.ctx_mem;

ctx = ggml_init(params);
if (!ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}

// key + value memory
{
auto & ctx = model.ctx_mem;

const auto & hparams = model.hparams;
{
const auto & hparams = model.hparams;

const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_text_ctx = hparams.n_text_ctx;

// key/value memory for the self-attention layer
{
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;

model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}

// key/value memory for the cross-attention layer
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
}

// create the ggml context for the key/value memory (cross-attention)
{
struct ggml_init_params params;
params.mem_size = wctx.buf_memory_cross.size();
params.mem_buffer = wctx.buf_memory_cross.data();

auto & ctx = model.ctx_mem_cross;

ctx = ggml_init(params);
if (!ctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}

{
const int n_audio_ctx = hparams.n_audio_ctx;
const auto & hparams = model.hparams;

const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_audio_ctx = hparams.n_audio_ctx;

const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
Expand All @@ -1007,10 +1040,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}

const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);

fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0);
}

Expand Down Expand Up @@ -2344,6 +2375,9 @@ void whisper_free(struct whisper_context * ctx) {
if (ctx->model.ctx_mem) {
ggml_free(ctx->model.ctx_mem);
}
if (ctx->model.ctx_mem_cross) {
ggml_free(ctx->model.ctx_mem_cross);
}
if (ctx->buf_model) {
delete ctx->buf_model;
}
Expand Down Expand Up @@ -3380,48 +3414,57 @@ int whisper_full_parallel(

auto & model = ctxs[i].model;

// create the ggml memory context
// separate key + value memory for each processor (self-attention)
{
struct ggml_init_params params;
params.mem_size = ctxs[i].buf_memory.size();
params.mem_buffer = ctxs[i].buf_memory.data();

model.ctx_mem = ggml_init(params);
if (!model.ctx_mem) {
auto & mctx = model.ctx_mem;
mctx = ggml_init(params);
if (!mctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
}

// separate key + value memory for each processor
{
auto & mctx = model.ctx_mem;

const auto & hparams = model.hparams;
{
const auto & hparams = model.hparams;

const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_text_ctx = hparams.n_text_ctx;
const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_text_ctx = hparams.n_text_ctx;

// key/value memory for the self-attention layer
{
const int n_mem = n_text_layer*n_text_ctx;
const int n_elements = n_text_state*n_mem;

model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
}

// key/value memory for the cross-attention layer
{
const int n_audio_ctx = hparams.n_audio_ctx;

const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;
// separate key + value memory for each processor (cross-attention)
{
struct ggml_init_params params;
params.mem_size = ctxs[i].buf_memory_cross.size();
params.mem_buffer = ctxs[i].buf_memory_cross.data();

model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
auto & mctx = model.ctx_mem_cross;
mctx = ggml_init(params);
if (!mctx) {
fprintf(stderr, "%s: ggml_init() failed\n", __func__);
return false;
}
const auto & hparams = model.hparams;

const int n_text_state = hparams.n_text_state;
const int n_text_layer = hparams.n_text_layer;
const int n_audio_ctx = hparams.n_audio_ctx;

const int n_mem = n_text_layer*n_audio_ctx;
const int n_elements = n_text_state*n_mem;

model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements);
}
}

Expand Down

0 comments on commit 20b64fc

Please sign in to comment.