From abe104b80bb1871f3219c894bd062902fb11c2b4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 18 Dec 2022 13:58:25 +0200 Subject: [PATCH 01/23] whisper : prepare infra for new decoding strategies --- whisper.cpp | 87 +++++++++++++++++++++++++++++++++++++++++------------ whisper.h | 5 +-- 2 files changed, 71 insertions(+), 21 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index a64505693f7..da43f5d3a80 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -16,6 +16,14 @@ #include #include +#define WHISPER_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + #define USE_FLASH_ATTN //#define USE_FLASH_FF @@ -423,8 +431,9 @@ struct whisper_context { std::vector logits; std::vector result_all; + std::vector prompt_past; - std::vector prompt_past; + std::vector work_logits; // used to avoid allocations // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; @@ -2689,12 +2698,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", + /*.suppress_blank =*/ true, + /*.greedy =*/ { - /*.n_past =*/ 0, + /*.dummy =*/ 0, }, /*.beam_search =*/ { - /*.n_past =*/ -1, /*.beam_width =*/ -1, /*.n_best =*/ -1, }, @@ -2738,12 +2748,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", + /*.suppress_blank =*/ true, + /*.greedy =*/ { - /*.n_past =*/ -1, + /*.dummy =*/ 0, }, /*.beam_search =*/ { - /*.n_past =*/ 0, /*.beam_width =*/ 10, /*.n_best =*/ 5, }, @@ -2822,6 +2833,50 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { return res; } +static struct whisper_token_data whisper_sample_next_token( + struct whisper_context * ctx, + struct whisper_full_params params, + const std::vector & prompt, + const std::vector & tokens_cur) { + struct whisper_token_data result = {}; + + const auto & vocab = ctx->vocab; + + const bool is_initial = tokens_cur.size() == 0; + const int n_logits = vocab.id_to_token.size(); + + WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab); + + // extract the logits for the last token + // we will be mutating and therefore we don't want to use the ctx->logits buffer directly + auto & logits = ctx->work_logits; + { + logits.resize(n_logits); + memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float)); + } + + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 + // TODO: apply logit filters here + { + } + + switch (params.strategy) { + case WHISPER_SAMPLING_GREEDY: + { + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274 + // TODO: implement + result = (is_initial) ? 
whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); } break; case WHISPER_SAMPLING_BEAM_SEARCH: { // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364 // TODO: implement } break; } return result; } int whisper_full( struct whisper_context * ctx, struct whisper_full_params params, @@ -2955,7 +3010,6 @@ int whisper_full( return -4; } - int n_past = 0; prompt.clear(); // if we have already generated some text, use it as a prompt to condition the next generation @@ -2971,8 +3025,6 @@ int whisper_full( prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); - int seek_delta = 100*WHISPER_CHUNK_SIZE; - // print the prompt //printf("\n\n"); //for (int i = 0; i < prompt.size(); i++) { //} //printf("\n\n"); + int n_past = 0; + int seek_delta = 100*WHISPER_CHUNK_SIZE; + // the accumulated transcription in the current iteration int result_len = 0; tokens_cur.clear(); - bool failed = false; + bool failed = false; // has the current segment failed to decode? bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); n_past += prompt.size(); prompt.clear(); - // very basic greedy sampling strategy: - // - // - always take the most probable token - // - // more sophisticated sampling strategies could be implemented here, but we keep it simple - // feel free to experiment! - // + // sample the next token based on the selected decoding strategy + parameters + // also, update the sliding window position based on the sampled timestamp tokens { - const auto token = (i == 0) ?
whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); + const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur); // timestamp token - update sliding window if (token.id > whisper_token_beg(ctx)) { @@ -3059,8 +3109,7 @@ int whisper_full( } // sometimes, the decoding can get stuck in a repetition loop - // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance - // the sliding window by 1 second + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { failed = true; break; diff --git a/whisper.h b/whisper.h index 63f61af5114..77fa89a2610 100644 --- a/whisper.h +++ b/whisper.h @@ -274,12 +274,13 @@ extern "C" { // for auto-detection, set to nullptr, "" or "auto" const char * language; + bool suppress_blank; + struct { - int n_past; + int dummy; } greedy; struct { - int n_past; int beam_width; int n_best; } beam_search; From 2d8d3724b8df3459b8740c1a587d08f81960fc35 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 Dec 2022 22:31:47 +0200 Subject: [PATCH 02/23] whisper : apply logit filters and compute logprobs --- .gitignore | 2 + whisper.cpp | 115 ++++++++++++++++++++++++++++++++++++++++++++++++++-- whisper.h | 3 ++ 3 files changed, 117 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 8a495199e75..5ca3702c331 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ build/ build-em/ build-debug/ build-release/ +build-static/ build-sanitize-addr/ build-sanitize-thread/ @@ -18,6 +19,7 @@ build-sanitize-thread/ /bench sync.sh +libwhisper.a libwhisper.so compile_commands.json diff --git a/whisper.cpp b/whisper.cpp index da43f5d3a80..d93596f85c3 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -433,7 +433,9 @@ struct whisper_context { std::vector result_all; std::vector prompt_past; - std::vector work_logits; // used to avoid allocations + // used to avoid allocations + std::vector work_logits; + std::vector work_logprobs; // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; @@ -2700,6 +2702,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, + /*.max_initial_timestamp =*/ 1.0, + /*.greedy =*/ { /*.dummy =*/ 0, }, @@ -2750,6 +2754,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, + /*.max_initial_timestamp =*/ 1.0, + /*.greedy =*/ { /*.dummy =*/ 0, }, @@ -2849,17 +2855,120 @@ static struct whisper_token_data whisper_sample_next_token( // extract the logits for the last token // we will be mutating and therefore we don't want to use the ctx->logits buffer directly - auto & logits = ctx->work_logits; + auto & logits = ctx->work_logits; + auto & logprobs = ctx->work_logprobs; { logits.resize(n_logits); memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float)); + + // will be populated a bit later + logprobs.resize(n_logits); } + // apply logit filters here // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 - // TODO: apply logit filters here { + // suppress blank + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 + if (params.suppress_blank) { + if (is_initial) { + logits[vocab.token_eot] = -INFINITY; + logits[vocab.token_to_id.at(" ")] = -INFINITY; + } + } + + 
// suppress <|notimestamps|> token + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 + logits[vocab.token_not] = -INFINITY; + + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 + { + const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + for (int i = vocab.token_beg; i < n_logits; ++ i) { + logits[i] = -INFINITY; + } + } else { + for (int i = 0; i < vocab.token_eot; ++ i) { + logits[i] = -INFINITY; + } + } + } + } + + // the initial timestamp cannot be larger than max_initial_timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial && params.max_initial_timestamp > 0.0f) { + const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_timestamp/precision); + + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) { + logits[i] = -INFINITY; + } + } + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++ i) { + logsumexp += expf(logits[i] - logit_max); + } + logsumexp = logf(logsumexp) + logit_max; + for (int i = 0; i < n_logits; ++ i) { + logprobs[i] = logits[i] - logsumexp; + } + } + + // if sum of probability over timestamps is above any other token, sample timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 + { + // logsumexp over timestamps + float timestamp_logprob = -INFINITY; + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); + for (int i = vocab.token_beg; i < n_logits; ++ i) { + logsumexp += expf(logprobs[i] - logprob_max); + } + logsumexp = logf(logsumexp) + logprob_max; + timestamp_logprob = logsumexp; + } + + const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + + if (timestamp_logprob > max_text_token_logprob) { + for (int i = 0; i < vocab.token_beg; ++ i) { + logits[i] = -INFINITY; + } + } + } } + // print first 100 logits - token string : logit + for (int i = 0; i < 100; i++) { + const auto token = vocab.id_to_token.at(i); + const auto logit = logits[i]; + printf("%s : %f\n", token.c_str(), logit); + } + + // "And", "and", " And", " and" + printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + 
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); + switch (params.strategy) { case WHISPER_SAMPLING_GREEDY: { diff --git a/whisper.h b/whisper.h index 77fa89a2610..1386a9a1511 100644 --- a/whisper.h +++ b/whisper.h @@ -274,8 +274,11 @@ extern "C" { // for auto-detection, set to nullptr, "" or "auto" const char * language; + // common decoding parameters: bool suppress_blank; + float max_initial_timestamp; + struct { int dummy; } greedy; From 21559537e2fc97bb8f54460b38fed057f7d233be Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 8 Jan 2023 16:40:53 +0200 Subject: [PATCH 03/23] whisper : add whisper_get_logits() --- whisper.cpp | 25 ++++++++++++++++++++----- whisper.h | 13 +++++++++++-- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index d93596f85c3..863bfe60212 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -620,6 +620,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + wctx.work_logits.reserve(vocab.n_vocab); + wctx.work_logprobs.reserve(vocab.n_vocab); + vocab.probs_id.reserve(n_vocab); } @@ -1004,11 +1007,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements); } - const size_t memory_size = - ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + - ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0); } // load weights @@ -2580,6 +2583,10 @@ int whisper_is_multilingual(struct whisper_context * ctx) { return ctx->vocab.is_multilingual() ? 
1 : 0; } +float * whisper_get_logits(struct whisper_context * ctx) { + return ctx->logits.data(); +} + float * whisper_get_probs(struct whisper_context * ctx) { return ctx->probs.data(); } @@ -2842,6 +2849,7 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { static struct whisper_token_data whisper_sample_next_token( struct whisper_context * ctx, struct whisper_full_params params, + double & sum_logprobs, const std::vector & prompt, const std::vector & tokens_cur) { struct whisper_token_data result = {}; @@ -2849,7 +2857,7 @@ static struct whisper_token_data whisper_sample_next_token( const auto & vocab = ctx->vocab; const bool is_initial = tokens_cur.size() == 0; - const int n_logits = vocab.id_to_token.size(); + const int n_logits = vocab.id_to_token.size(); WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab); @@ -2949,6 +2957,7 @@ static struct whisper_token_data whisper_sample_next_token( } } +#if 0 // print first 100 logits - token string : logit for (int i = 0; i < 100; i++) { const auto token = vocab.id_to_token.at(i); const auto logit = logits[i]; printf("%s : %f\n", token.c_str(), logit); } @@ -2968,6 +2977,7 @@ printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); +#endif switch (params.strategy) { case WHISPER_SAMPLING_GREEDY: { @@ -2983,6 +2993,9 @@ } break; } + sum_logprobs += logprobs[result.id]; + printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1)); + return result; } @@ -3151,6 +3164,8 @@ int whisper_full( bool failed = false; // has the current segment failed to decode? bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? + double sum_logprobs = 0.0; + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); @@ -3163,7 +3178,7 @@ int whisper_full( // sample the next token based on the selected decoding strategy + parameters // also, update the sliding window position based on the sampled timestamp tokens { - const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur); + const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur); // timestamp token - update sliding window if (token.id > whisper_token_beg(ctx)) { diff --git a/whisper.h b/whisper.h index 1386a9a1511..ccc1ff96eca 100644 --- a/whisper.h +++ b/whisper.h @@ -145,7 +145,7 @@ extern "C" { // Token sampling methods. // These are provided for convenience and can be used after each call to whisper_decode(). - // You can also implement your own sampling method using the whisper_get_probs() function. + // You can also implement your own sampling method using the whisper_get_logits() or whisper_get_probs() functions.
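+    //
+    // For example (editor's sketch, not part of the original patch), a minimal greedy
+    // pick over the logits of the last decoded token could look like this, assuming
+    // `n_tokens` is the token count passed to the preceding whisper_decode() call:
+    //
+    //     const int     n_vocab = whisper_n_vocab(ctx);
+    //     const float * logits  = whisper_get_logits(ctx) + (n_tokens - 1)*n_vocab;
+    //
+    //     whisper_token best_id = 0;
+    //     for (int i = 1; i < n_vocab; ++i) {
+    //         if (logits[i] > logits[best_id]) {
+    //             best_id = i;
+    //         }
+    //     }
+    //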
// whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); @@ -192,7 +192,16 @@ extern "C" { WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); - // The probabilities for the next token + // Token logits obtained from the last call to whisper_decode() + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); + + // Token probabilities (i.e. softmax(logits)) obtained from the last call to whisper_decode() + // The probabilities for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); // Token Id -> String. Uses the vocabulary in the provided context From 523e0494a662e16090a9aab855d26e1dcbec3f6b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 8 Jan 2023 19:11:21 +0200 Subject: [PATCH 04/23] whisper : separate self and cross attention memory Initial step needed for supporting parallel decoders --- whisper.cpp | 135 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 89 insertions(+), 46 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 863bfe60212..d57dd9c5f09 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -150,11 +150,19 @@ static const std::map MEM_REQ_MODEL = { }; static const std::map MEM_REQ_MEMORY = { - { MODEL_TINY, 12ull*MB }, - { MODEL_BASE, 24ull*MB }, - { MODEL_SMALL, 70ull*MB }, - { MODEL_MEDIUM, 184ull*MB }, - { MODEL_LARGE, 306ull*MB }, + { MODEL_TINY, 3ull*MB }, + { MODEL_BASE, 6ull*MB }, + { MODEL_SMALL, 16ull*MB }, + { MODEL_MEDIUM, 43ull*MB }, + { MODEL_LARGE, 71ull*MB }, +}; + +static const std::map MEM_REQ_MEMORY_CROSS = { + { MODEL_TINY, 9ull*MB }, + { MODEL_BASE, 18ull*MB }, + { MODEL_SMALL, 53ull*MB }, + { MODEL_MEDIUM, 141ull*MB }, + { MODEL_LARGE, 235ull*MB }, }; static const std::map MEM_REQ_ENCODE = { @@ -391,22 +399,27 @@ struct whisper_model { std::vector layers_encoder; std::vector layers_decoder; - // key + value memory + // key + value memory for self attention struct ggml_tensor * memory_k; struct ggml_tensor * memory_v; + // key + value memory for cross attention struct ggml_tensor * memory_cross_k; struct ggml_tensor * memory_cross_v; // context struct ggml_context * ctx; struct ggml_context * ctx_mem; + struct ggml_context * ctx_mem_cross; // tensors int n_loaded; std::map tensors; }; +struct whisper_decoder_data { +}; + struct whisper_context { int64_t t_load_us = 0; int64_t t_mel_us = 0; @@ -417,6 +430,7 @@ struct whisper_context { std::vector * buf_model; // the model buffer is read-only and can be shared between processors std::vector buf_memory; + std::vector buf_memory_cross; std::vector buf_compute; std::vector buf_compute_layer; @@ -533,6 +547,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.buf_model = new std::vector(); wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); + wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type)); wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); } @@ -631,6 
+646,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con const size_t mem_required = wctx.buf_model->size() + wctx.buf_memory.size() + + wctx.buf_memory_cross.size() + wctx.buf_compute.size() + wctx.buf_compute_layer.size(); @@ -964,31 +980,27 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - // create the ggml memory context + // create the ggml context for the key/value memory (self-attention) { struct ggml_init_params params; params.mem_size = wctx.buf_memory.size(); params.mem_buffer = wctx.buf_memory.data(); - model.ctx_mem = ggml_init(params); - if (!model.ctx_mem) { + auto & ctx = model.ctx_mem; + + ctx = ggml_init(params); + if (!ctx) { fprintf(stderr, "%s: ggml_init() failed\n", __func__); return false; } - } - // key + value memory - { - auto & ctx = model.ctx_mem; - - const auto & hparams = model.hparams; + { + const auto & hparams = model.hparams; - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_text_ctx = hparams.n_text_ctx; - // key/value memory for the self-attention layer - { const int n_mem = n_text_layer*n_text_ctx; const int n_elements = n_text_state*n_mem; @@ -996,9 +1008,30 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements); } - // key/value memory for the cross-attention layer + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + // create the ggml context for the key/value memory (cross-attention) + { + struct ggml_init_params params; + params.mem_size = wctx.buf_memory_cross.size(); + params.mem_buffer = wctx.buf_memory_cross.data(); + + auto & ctx = model.ctx_mem_cross; + + ctx = ggml_init(params); + if (!ctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; + } + { - const int n_audio_ctx = hparams.n_audio_ctx; + const auto & hparams = model.hparams; + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_audio_ctx = hparams.n_audio_ctx; const int n_mem = n_text_layer*n_audio_ctx; const int n_elements = n_text_state*n_mem; @@ -1007,10 +1040,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements); } - const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); - fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0); } @@ -2345,6 +2376,9 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.ctx_mem) { ggml_free(ctx->model.ctx_mem); } + if (ctx->model.ctx_mem_cross) { + ggml_free(ctx->model.ctx_mem_cross); + } if (ctx->buf_model) { delete ctx->buf_model; } @@ -3381,48 +3415,57 @@ int whisper_full_parallel( auto & model = ctxs[i].model; - // create the ggml memory context + // separate key + value memory for each processor (self-attention) { struct ggml_init_params params; params.mem_size = ctxs[i].buf_memory.size(); params.mem_buffer 
= ctxs[i].buf_memory.data(); - model.ctx_mem = ggml_init(params); - if (!model.ctx_mem) { + auto & mctx = model.ctx_mem; + mctx = ggml_init(params); + if (!mctx) { fprintf(stderr, "%s: ggml_init() failed\n", __func__); return false; } - } - // separate key + value memory for each processor - { - auto & mctx = model.ctx_mem; - - const auto & hparams = model.hparams; + { + const auto & hparams = model.hparams; - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_text_ctx = hparams.n_text_ctx; - // key/value memory for the self-attention layer - { const int n_mem = n_text_layer*n_text_ctx; const int n_elements = n_text_state*n_mem; model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); } + } - // key/value memory for the cross-attention layer - { - const int n_audio_ctx = hparams.n_audio_ctx; - - const int n_mem = n_text_layer*n_audio_ctx; - const int n_elements = n_text_state*n_mem; + // separate key + value memory for each processor (cross-attention) + { + struct ggml_init_params params; + params.mem_size = ctxs[i].buf_memory_cross.size(); + params.mem_buffer = ctxs[i].buf_memory_cross.data(); - model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); - model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); + auto & mctx = model.ctx_mem_cross; + mctx = ggml_init(params); + if (!mctx) { + fprintf(stderr, "%s: ggml_init() failed\n", __func__); + return false; } + const auto & hparams = model.hparams; + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + const int n_audio_ctx = hparams.n_audio_ctx; + + const int n_mem = n_text_layer*n_audio_ctx; + const int n_elements = n_text_state*n_mem; + + model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); + model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); } } From 1163f266d225cd020568679d1e2c3decadb1e35f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 9 Jan 2023 19:02:13 +0200 Subject: [PATCH 05/23] whisper : move probs_id buffer to whisper_context --- whisper.cpp | 46 ++++++++++++++++++++++++---------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index d57dd9c5f09..5b7b3fef6d6 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -220,10 +220,6 @@ struct whisper_vocab { std::map token_to_id; std::map id_to_token; - // used to avoid memory allocations during sampling - // TODO: move to whisper_context in the future - std::vector> probs_id; - id token_eot = 50256; id token_sot = 50257; id token_prev = 50360; @@ -387,14 +383,14 @@ struct whisper_model { struct ggml_tensor * e_ln_b; // decoder.positional_embedding - struct ggml_tensor * d_pe; // DD + struct ggml_tensor * d_pe; // decoder.token_embedding - struct ggml_tensor * d_te; // DD + struct ggml_tensor * d_te; // decoder.ln - struct ggml_tensor * d_ln_w; // DD - struct ggml_tensor * d_ln_b; // DD + struct ggml_tensor * d_ln_w; + struct ggml_tensor * d_ln_b; std::vector layers_encoder; std::vector layers_decoder; @@ -451,6 +447,8 @@ struct whisper_context { std::vector work_logits; std::vector work_logprobs; + std::vector> probs_id; + // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; int64_t t_last; @@ -545,10 +543,10 @@ static bool 
whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: type = %d\n", __func__, model.type); wctx.buf_model = new std::vector(); - wctx.buf_model->resize(MEM_REQ_MODEL.at(model.type)); - wctx.buf_memory.resize(MEM_REQ_MEMORY.at(model.type)); - wctx.buf_memory_cross.resize(MEM_REQ_MEMORY_CROSS.at(model.type)); - wctx.buf_compute.resize(std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + wctx.buf_model->resize (MEM_REQ_MODEL.at(model.type)); + wctx.buf_memory.resize (MEM_REQ_MEMORY.at(model.type)); + wctx.buf_memory_cross.resize (MEM_REQ_MEMORY_CROSS.at(model.type)); + wctx.buf_compute.resize (std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); } @@ -638,7 +636,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.work_logits.reserve(vocab.n_vocab); wctx.work_logprobs.reserve(vocab.n_vocab); - vocab.probs_id.reserve(n_vocab); + wctx.probs_id.reserve(n_vocab); } { @@ -1008,7 +1006,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements); } - const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } @@ -1900,17 +1898,19 @@ static bool whisper_decode( // the most basic sampling scheme - select the top token static whisper_token_data whisper_sample_best( - whisper_vocab & vocab, + whisper_context & ctx, const float * probs, - bool force_timestamp, - bool is_initial) { + bool force_timestamp, + bool is_initial) { whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; + const auto & vocab = ctx.vocab; + const int n_logits = vocab.n_vocab; - auto & probs_id = vocab.probs_id; + auto & probs_id = ctx.probs_id; probs_id.clear(); for (int i = 0; i < n_logits; i++) { @@ -2461,7 +2461,7 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); - const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); + const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2471,7 +2471,7 @@ struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { const int64_t t_start_sample_us = ggml_time_us(); - const auto res = whisper_sample_best(ctx->vocab, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); + const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; @@ -2558,7 +2558,9 @@ int whisper_lang_auto_detect( return -7; } - std::vector> probs_id; + auto & probs_id = ctx->probs_id; + probs_id.clear(); + for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); probs_id.emplace_back(ctx->probs[token_lang], kv.second.first); @@ -2566,7 +2568,7 @@ 
int whisper_lang_auto_detect( // sort descending { - using pair_type = decltype(probs_id)::value_type; + using pair_type = std::remove_reference::type::value_type; std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { return a.first > b.first; }); From ee58108df8d4f46b08a3497016e66effb288f0a6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 9 Jan 2023 20:42:32 +0200 Subject: [PATCH 06/23] whisper : refactor kv cache into separate struct --- whisper.cpp | 279 ++++++++++++++++++++++------------------------------ whisper.h | 8 -- 2 files changed, 116 insertions(+), 171 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 5b7b3fef6d6..a5b402bc975 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -149,7 +149,7 @@ static const std::map MEM_REQ_MODEL = { { MODEL_LARGE, 2952ull*MB }, }; -static const std::map MEM_REQ_MEMORY = { +static const std::map MEM_REQ_KV_SELF = { { MODEL_TINY, 3ull*MB }, { MODEL_BASE, 6ull*MB }, { MODEL_SMALL, 16ull*MB }, @@ -157,7 +157,7 @@ static const std::map MEM_REQ_MEMORY = { { MODEL_LARGE, 71ull*MB }, }; -static const std::map MEM_REQ_MEMORY_CROSS = { +static const std::map MEM_REQ_KV_CROSS = { { MODEL_TINY, 9ull*MB }, { MODEL_BASE, 18ull*MB }, { MODEL_SMALL, 53ull*MB }, @@ -361,6 +361,15 @@ struct whisper_layer_decoder { struct ggml_tensor * mlp_1_b; }; +struct whisper_kv_cache { + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx; + + std::vector buf; +}; + struct whisper_model { e_model type = MODEL_UNKNOWN; @@ -395,25 +404,21 @@ struct whisper_model { std::vector layers_encoder; std::vector layers_decoder; - // key + value memory for self attention - struct ggml_tensor * memory_k; - struct ggml_tensor * memory_v; - - // key + value memory for cross attention - struct ggml_tensor * memory_cross_k; - struct ggml_tensor * memory_cross_v; - // context struct ggml_context * ctx; - struct ggml_context * ctx_mem; - struct ggml_context * ctx_mem_cross; + + // the model memory buffer is read-only and can be shared between processors + std::vector * buf; // tensors int n_loaded; std::map tensors; }; -struct whisper_decoder_data { +struct whisper_decoder { +}; + +struct whisper_sequence { }; struct whisper_context { @@ -424,17 +429,18 @@ struct whisper_context { int64_t t_decode_us = 0; int64_t t_start_us = 0; - std::vector * buf_model; // the model buffer is read-only and can be shared between processors - std::vector buf_memory; - std::vector buf_memory_cross; - std::vector buf_compute; - std::vector buf_compute_layer; + // memory buffers used by encode / decode contexts + std::vector buf_compute; + std::vector buf_compute_layer; ggml_type wtype; // weight type (FP32 or FP16) whisper_model model; whisper_vocab vocab; + whisper_kv_cache kv_self; + whisper_kv_cache kv_cross; + whisper_mel mel; std::vector probs; @@ -464,6 +470,34 @@ static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); } +static bool init_kv_cache( + const struct whisper_hparams & hparams, + struct whisper_kv_cache & cache, + ggml_type wtype, + int n_ctx) { + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mem = n_text_layer*n_ctx; + const int 
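+    // [editor's note, illustrative] sizing example for the base model
+    // (n_text_layer = 6, n_text_ctx = 448, n_text_state = 512): the cache holds
+    // 512*6*448 = 1376256 elements per tensor, so k + v in f16 take about
+    // 2*2*1376256 bytes ~ 5.5 MB, consistent with MEM_REQ_KV_SELF above;
+    // the cross-attention cache uses n_audio_ctx = 1500 in place of n_ctx.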
n_elements = n_text_state*n_mem; + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + // load the model from a ggml file // // file format: @@ -542,12 +576,20 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); fprintf(stderr, "%s: type = %d\n", __func__, model.type); - wctx.buf_model = new std::vector(); - wctx.buf_model->resize (MEM_REQ_MODEL.at(model.type)); - wctx.buf_memory.resize (MEM_REQ_MEMORY.at(model.type)); - wctx.buf_memory_cross.resize (MEM_REQ_MEMORY_CROSS.at(model.type)); - wctx.buf_compute.resize (std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); - wctx.buf_compute_layer.resize(std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); + // for the big tensors, we have the option to store the data in 16-bit floats + // in order to save memory and also to speed up the computation + wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + const size_t scale = model.hparams.f16 ? 1 : 2; + + wctx.model.buf = new std::vector(); + wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); + + wctx.kv_self.buf.resize (scale*MEM_REQ_KV_SELF.at(model.type)); + wctx.kv_cross.buf.resize(scale*MEM_REQ_KV_CROSS.at(model.type)); + + wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); + wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); } // load mel filters @@ -642,23 +684,19 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con { // this is the total memory required to run the inference const size_t mem_required = - wctx.buf_model->size() + - wctx.buf_memory.size() + - wctx.buf_memory_cross.size() + + wctx.model.buf->size() + + wctx.kv_self.buf.size() + + wctx.kv_cross.buf.size() + wctx.buf_compute.size() + wctx.buf_compute_layer.size(); - fprintf(stderr, "%s: mem_required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); + fprintf(stderr, "%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); } - // for the big tensors, we have the option to store the data in 16-bit floats - // in order to save memory and also to speed up the computation - wctx.wtype = model.hparams.f16 ? 
GGML_TYPE_F16 : GGML_TYPE_F32; + size_t ctx_size = 0; const ggml_type wtype = wctx.wtype; - size_t ctx_size = 0; - { const auto & hparams = model.hparams; @@ -766,14 +804,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con ctx_size += (15 + 15*n_audio_layer + 24*n_text_layer)*256; // object overhead - fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + fprintf(stderr, "%s: model ctx = %7.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); } // create the ggml context { struct ggml_init_params params; - params.mem_size = wctx.buf_model->size(); - params.mem_buffer = wctx.buf_model->data(); + params.mem_size = wctx.model.buf->size(); + params.mem_buffer = wctx.model.buf->data(); model.ctx = ggml_init(params); if (!model.ctx) { @@ -978,69 +1016,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - // create the ggml context for the key/value memory (self-attention) - { - struct ggml_init_params params; - params.mem_size = wctx.buf_memory.size(); - params.mem_buffer = wctx.buf_memory.data(); - - auto & ctx = model.ctx_mem; - - ctx = ggml_init(params); - if (!ctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - - { - const auto & hparams = model.hparams; - - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; - - const int n_mem = n_text_layer*n_text_ctx; - const int n_elements = n_text_state*n_mem; - - model.memory_k = ggml_new_tensor_1d(ctx, wtype, n_elements); - model.memory_v = ggml_new_tensor_1d(ctx, wtype, n_elements); - } - - const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); - fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + // TODO: move to decoder + if (!init_kv_cache(model.hparams, wctx.kv_self, wctx.wtype, model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__); + return false; } - // create the ggml context for the key/value memory (cross-attention) { - struct ggml_init_params params; - params.mem_size = wctx.buf_memory_cross.size(); - params.mem_buffer = wctx.buf_memory_cross.data(); - - auto & ctx = model.ctx_mem_cross; - - ctx = ggml_init(params); - if (!ctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - - { - const auto & hparams = model.hparams; - - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_audio_ctx = hparams.n_audio_ctx; - - const int n_mem = n_text_layer*n_audio_ctx; - const int n_elements = n_text_state*n_mem; - - model.memory_cross_k = ggml_new_tensor_1d(ctx, wtype, n_elements); - model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements); - } + const size_t memory_size = ggml_nbytes(wctx.kv_self.k) + ggml_nbytes(wctx.kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } - const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); + if (!init_kv_cache(model.hparams, wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__); + return false; + } - fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0); + { + const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); + 
fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); } // load weights @@ -1504,10 +1498,10 @@ static bool whisper_encode( Vcross), Vcross); - //struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - //struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); - struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_cross_k, n_state*n_ctx, (ggml_element_size(model.memory_cross_k)*n_state)*(il*n_ctx)); - struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_cross_v, n_state*n_ctx, (ggml_element_size(model.memory_cross_v)*n_state)*(il*n_ctx)); + //struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + //struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*hparams.n_audio_ctx + iter*n_ctx)); + struct ggml_tensor * k = ggml_view_1d(ctx0, wctx.kv_cross.k, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.k)*n_state)*(il*n_ctx)); + struct ggml_tensor * v = ggml_view_1d(ctx0, wctx.kv_cross.v, n_state*n_ctx, (ggml_element_size(wctx.kv_cross.v)*n_state)*(il*n_ctx)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcross, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcross, v)); @@ -1636,8 +1630,8 @@ static bool whisper_decode( // store key and value to memory { - struct ggml_tensor * k = ggml_view_1d(ctxL, model.memory_k, N*n_state, (ggml_element_size(model.memory_k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctxL, model.memory_v, N*n_state, (ggml_element_size(model.memory_v)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctxL, wctx.kv_self.k, N*n_state, (ggml_element_size(wctx.kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctxL, wctx.kv_self.v, N*n_state, (ggml_element_size(wctx.kv_self.v)*n_state)*(il*n_ctx + n_past)); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v)); @@ -1655,7 +1649,7 @@ static bool whisper_decode( struct ggml_tensor * K = ggml_permute(ctxL, ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state), + ggml_view_1d(ctxL, wctx.kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.k)*n_state), n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); @@ -1675,7 +1669,7 @@ static bool whisper_decode( struct ggml_tensor * V_trans = ggml_permute(ctxL, ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state), + ggml_view_1d(ctxL, wctx.kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.v)*n_state), n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); @@ -1730,12 +1724,12 @@ static bool whisper_decode( // Kcross is already scaled struct ggml_tensor * Kcross = ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_cross_k, M*n_state, il*M*ggml_element_size(model.memory_cross_k)*n_state), + ggml_view_1d(ctxL, wctx.kv_cross.k, M*n_state, il*M*ggml_element_size(wctx.kv_cross.k)*n_state), n_state/n_head, n_head, M); struct ggml_tensor * Vcross = ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, model.memory_cross_v, M*n_state, 
il*M*ggml_element_size(model.memory_cross_v)*n_state), + ggml_view_1d(ctxL, wctx.kv_cross.v, M*n_state, il*M*ggml_element_size(wctx.kv_cross.v)*n_state), n_state/n_head, n_head, M); // ------ @@ -2373,14 +2367,14 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.ctx) { ggml_free(ctx->model.ctx); } - if (ctx->model.ctx_mem) { - ggml_free(ctx->model.ctx_mem); + if (ctx->model.buf) { + delete ctx->model.buf; } - if (ctx->model.ctx_mem_cross) { - ggml_free(ctx->model.ctx_mem_cross); + if (ctx->kv_self.ctx) { + ggml_free(ctx->kv_self.ctx); } - if (ctx->buf_model) { - delete ctx->buf_model; + if (ctx->kv_cross.ctx) { + ggml_free(ctx->kv_cross.ctx); } delete ctx; } @@ -2458,7 +2452,8 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i return 0; } -struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { +// TODO: remove +static struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { const int64_t t_start_sample_us = ggml_time_us(); const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), false, false); @@ -2468,7 +2463,8 @@ struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) { return res; } -struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { +// TODO: remove +static struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) { const int64_t t_start_sample_us = ggml_time_us(); const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial); @@ -3413,61 +3409,18 @@ int whisper_full_parallel( std::vector ctxs(n_processors - 1); for (int i = 0; i < n_processors - 1; ++i) { - ctxs[i] = *ctx; - - auto & model = ctxs[i].model; - - // separate key + value memory for each processor (self-attention) - { - struct ggml_init_params params; - params.mem_size = ctxs[i].buf_memory.size(); - params.mem_buffer = ctxs[i].buf_memory.data(); - - auto & mctx = model.ctx_mem; - mctx = ggml_init(params); - if (!mctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - - { - const auto & hparams = model.hparams; - - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_text_ctx = hparams.n_text_ctx; + auto & ctx_p = ctxs[i]; - const int n_mem = n_text_layer*n_text_ctx; - const int n_elements = n_text_state*n_mem; + ctx_p = *ctx; - model.memory_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); - model.memory_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); - } + if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_self, ctx_p.wtype, ctx_p.model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__); + return false; } - // separate key + value memory for each processor (cross-attention) - { - struct ggml_init_params params; - params.mem_size = ctxs[i].buf_memory_cross.size(); - params.mem_buffer = ctxs[i].buf_memory_cross.data(); - - auto & mctx = model.ctx_mem_cross; - mctx = ggml_init(params); - if (!mctx) { - fprintf(stderr, "%s: ggml_init() failed\n", __func__); - return false; - } - const auto & hparams = model.hparams; - - const int n_text_state = hparams.n_text_state; - const int n_text_layer = hparams.n_text_layer; - const int n_audio_ctx = hparams.n_audio_ctx; - - const int n_mem = n_text_layer*n_audio_ctx; - const int n_elements = n_text_state*n_mem; - - 
model.memory_cross_k = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); - model.memory_cross_v = ggml_new_tensor_1d(mctx, ctx->wtype, n_elements); + if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_cross, ctx_p.wtype, ctx_p.model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__); + return false; } } diff --git a/whisper.h b/whisper.h index ccc1ff96eca..2983765275a 100644 --- a/whisper.h +++ b/whisper.h @@ -143,14 +143,6 @@ extern "C" { int n_past, int n_threads); - // Token sampling methods. - // These are provided for convenience and can be used after each call to whisper_decode(). - // You can also implement your own sampling method using the whiper_get_logits() or whisper_get_probs() functions. - // whisper_sample_best() returns the token with the highest probability - // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); - WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial); - // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens From 628843c60aa0af8610047c8b902938937ea15da1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Jan 2023 19:31:11 +0200 Subject: [PATCH 07/23] whisper : move self-attention kv cache to whisper_decoder --- whisper.cpp | 192 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 128 insertions(+), 64 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index a5b402bc975..52c70dcddb5 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -24,8 +24,9 @@ } \ } while (0) -#define USE_FLASH_ATTN -//#define USE_FLASH_FF +#define WHISPER_USE_FLASH_ATTN +//#define WHISPER_USE_FLASH_FF +#define WHISPER_MAX_DECODERS 16 // available whisper models enum e_model { @@ -416,6 +417,7 @@ struct whisper_model { }; struct whisper_decoder { + whisper_kv_cache kv_self; }; struct whisper_sequence { @@ -429,20 +431,24 @@ struct whisper_context { int64_t t_decode_us = 0; int64_t t_start_us = 0; - // memory buffers used by encode / decode contexts - std::vector buf_compute; - std::vector buf_compute_layer; - ggml_type wtype; // weight type (FP32 or FP16) + whisper_mel mel; + whisper_model model; whisper_vocab vocab; - whisper_kv_cache kv_self; whisper_kv_cache kv_cross; - whisper_mel mel; + whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; + int selected_decoder_id = 0; + + // memory buffers used by encode / decode contexts + std::vector buf_compute; + std::vector buf_compute_layer; + + // decode output std::vector probs; std::vector logits; @@ -470,11 +476,14 @@ static void read_safe(whisper_model_loader * loader, T & dest) { loader->read(loader->context, &dest, sizeof(T)); } -static bool init_kv_cache( +static bool kv_cache_init( const struct whisper_hparams & hparams, + const size_t mem_bytes, struct whisper_kv_cache & cache, ggml_type wtype, int n_ctx) { + cache.buf.resize(mem_bytes); + struct ggml_init_params params; params.mem_size = cache.buf.size(); params.mem_buffer = cache.buf.data(); @@ -498,6 +507,41 @@ static bool init_kv_cache( return true; } +static bool kv_cache_reinit(struct whisper_kv_cache & cache) { + WHISPER_ASSERT(cache.ctx); + + const int n_elements = ggml_nelements(cache.k); + WHISPER_ASSERT(n_elements == ggml_nelements(cache.v)); + + const ggml_type wtype = cache.k->type; + WHISPER_ASSERT(wtype == 
cache.v->type); + + WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype)); + + struct ggml_init_params params; + params.mem_size = cache.buf.size(); + params.mem_buffer = cache.buf.data(); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + return true; +} + +static void kv_cache_free(struct whisper_kv_cache & cache) { + if (cache.ctx) { + ggml_free(cache.ctx); + cache.ctx = nullptr; + } +} + // load the model from a ggml file // // file format: @@ -563,6 +607,12 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.type = e_model::MODEL_LARGE; } + // for the big tensors, we have the option to store the data in 16-bit floats + // in order to save memory and also to speed up the computation + wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + const size_t scale = model.hparams.f16 ? 1 : 2; + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); @@ -576,17 +626,47 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con fprintf(stderr, "%s: f16 = %d\n", __func__, hparams.f16); fprintf(stderr, "%s: type = %d\n", __func__, model.type); - // for the big tensors, we have the option to store the data in 16-bit floats - // in order to save memory and also to speed up the computation - wctx.wtype = model.hparams.f16 ? GGML_TYPE_F16 : GGML_TYPE_F32; + // print memory requirements + { + // this is the total memory required to run the inference + const size_t mem_required = + scale*MEM_REQ_MODEL.at (model.type) + + scale*MEM_REQ_KV_CROSS.at (model.type) + + scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type)) + + scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type)); - const size_t scale = model.hparams.f16 ? 
1 : 2; + // this is the memory required by one decoder + const size_t mem_required_decoder = + scale*MEM_REQ_KV_SELF.at(model.type); + + fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per decoder)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); + } wctx.model.buf = new std::vector(); wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); - wctx.kv_self.buf.resize (scale*MEM_REQ_KV_SELF.at(model.type)); - wctx.kv_cross.buf.resize(scale*MEM_REQ_KV_CROSS.at(model.type)); + wctx.selected_decoder_id = 0; + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.decoders[0].kv_self.k) + ggml_nbytes(wctx.decoders[0].kv_self.v); + fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } + + if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_CROSS.at(model.type), wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { + fprintf(stderr, "%s: kv_cache_init() failed for cross-attention cache\n", __func__); + return false; + } + + { + const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); + fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + } wctx.buf_compute.resize (scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type))); wctx.buf_compute_layer.resize(scale*std::max(MEM_REQ_ENCODE_LAYER.at(model.type), MEM_REQ_DECODE_LAYER.at(model.type))); @@ -673,26 +753,14 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx); - wctx.work_logits.reserve(vocab.n_vocab); + wctx.work_logits.reserve (vocab.n_vocab); wctx.work_logprobs.reserve(vocab.n_vocab); wctx.probs_id.reserve(n_vocab); } - { - // this is the total memory required to run the inference - const size_t mem_required = - wctx.model.buf->size() + - wctx.kv_self.buf.size() + - wctx.kv_cross.buf.size() + - wctx.buf_compute.size() + - wctx.buf_compute_layer.size(); - - fprintf(stderr, "%s: mem required = %7.2f MB\n", __func__, mem_required / 1024.0 / 1024.0); - } - size_t ctx_size = 0; const ggml_type wtype = wctx.wtype; @@ -1016,27 +1084,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - // TODO: move to decoder - if (!init_kv_cache(model.hparams, wctx.kv_self, wctx.wtype, model.hparams.n_text_ctx)) { - fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.kv_self.k) + ggml_nbytes(wctx.kv_self.v); - fprintf(stderr, "%s: kv self size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - - if (!init_kv_cache(model.hparams, wctx.kv_cross, wctx.wtype, model.hparams.n_audio_ctx)) { - fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__); - return false; - } - - { - const size_t memory_size = ggml_nbytes(wctx.kv_cross.k) + ggml_nbytes(wctx.kv_cross.v); - fprintf(stderr, "%s: kv cross size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); - } - // load weights { size_t total_size = 0; @@ -1266,7 +1313,7 @@ static bool whisper_encode( // ------ -#ifdef 
USE_FLASH_ATTN +#ifdef WHISPER_USE_FLASH_ATTN struct ggml_tensor * Q = ggml_permute(ctxL, ggml_cpy(ctxL, @@ -1377,7 +1424,7 @@ static bool whisper_encode( ggml_repeat(ctxL, layer.mlp_ln_b, cur)); } -#ifdef USE_FLASH_FF +#ifdef WHISPER_USE_FLASH_FF cur = ggml_flash_ff(ctxL, ggml_cpy(ctxL, cur, ggml_new_tensor_2d(ctxL, wctx.wtype, n_state, N)), layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); @@ -1539,6 +1586,13 @@ static bool whisper_decode( const auto & model = wctx.model; const auto & hparams = model.hparams; + WHISPER_ASSERT(wctx.selected_decoder_id >= 0); + WHISPER_ASSERT(wctx.selected_decoder_id < WHISPER_MAX_DECODERS); + + auto & kv_self = wctx.decoders[wctx.selected_decoder_id].kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + auto & logits_out = wctx.logits; auto & probs_out = wctx.probs; @@ -1630,8 +1684,8 @@ static bool whisper_decode( // store key and value to memory { - struct ggml_tensor * k = ggml_view_1d(ctxL, wctx.kv_self.k, N*n_state, (ggml_element_size(wctx.kv_self.k)*n_state)*(il*n_ctx + n_past)); - struct ggml_tensor * v = ggml_view_1d(ctxL, wctx.kv_self.v, N*n_state, (ggml_element_size(wctx.kv_self.v)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * k = ggml_view_1d(ctxL, kv_self.k, N*n_state, (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_1d(ctxL, kv_self.v, N*n_state, (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + n_past)); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Kcur, k)); ggml_build_forward_expand(&gf, ggml_cpy(ctxL, Vcur, v)); @@ -1649,7 +1703,7 @@ static bool whisper_decode( struct ggml_tensor * K = ggml_permute(ctxL, ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, wctx.kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.k)*n_state), + ggml_view_1d(ctxL, kv_self.k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.k)*n_state), n_state/n_head, n_head, n_past + N), 0, 2, 1, 3); @@ -1669,7 +1723,7 @@ static bool whisper_decode( struct ggml_tensor * V_trans = ggml_permute(ctxL, ggml_reshape_3d(ctxL, - ggml_view_1d(ctxL, wctx.kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(wctx.kv_self.v)*n_state), + ggml_view_1d(ctxL, kv_self.v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(kv_self.v)*n_state), n_state/n_head, n_head, n_past + N), 1, 2, 0, 3); @@ -2370,12 +2424,14 @@ void whisper_free(struct whisper_context * ctx) { if (ctx->model.buf) { delete ctx->model.buf; } - if (ctx->kv_self.ctx) { - ggml_free(ctx->kv_self.ctx); - } if (ctx->kv_cross.ctx) { ggml_free(ctx->kv_cross.ctx); } + for (int i = 0; i < WHISPER_MAX_DECODERS; ++i) { + if (ctx->decoders[i].kv_self.ctx) { + ggml_free(ctx->decoders[i].kv_self.ctx); + } + } delete ctx; } } @@ -3413,14 +3469,16 @@ int whisper_full_parallel( ctx_p = *ctx; - if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_self, ctx_p.wtype, ctx_p.model.hparams.n_text_ctx)) { - fprintf(stderr, "%s: init_kv_cache() failed for self-attention cache\n", __func__); + if (!kv_cache_reinit(ctx_p.kv_cross)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention cache\n", __func__); return false; } - if (!init_kv_cache(ctx_p.model.hparams, ctx_p.kv_cross, ctx_p.wtype, ctx_p.model.hparams.n_audio_ctx)) { - fprintf(stderr, "%s: init_kv_cache() failed for cross-attention cache\n", __func__); - return false; + for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { + if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention cache\n", __func__); + 
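// Note: kv_cache_reinit() is introduced earlier in this series and its body is
// not shown in this hunk. The reason it is needed is visible above: `ctx_p = *ctx;`
// copies each cache's std::vector buffer, but cache.ctx / cache.k / cache.v still
// point into the original context's storage, so every copy must rebuild its ggml
// context over its own buffer. A minimal sketch of such a helper, assuming the
// whisper_kv_cache layout used by kv_cache_init() above; the name
// kv_cache_reinit_sketch and the recovery of n_elements/wtype from cache.k are
// illustrative, not necessarily how the real helper does it:

static bool kv_cache_reinit_sketch(struct whisper_kv_cache & cache) {
    const int       n_elements = cache.k->ne[0]; // same element count as at init time
    const ggml_type wtype      = cache.k->type;

    WHISPER_ASSERT(cache.buf.size() >= 2*n_elements*ggml_type_size(wtype));

    struct ggml_init_params params;
    params.mem_size   = cache.buf.size();
    params.mem_buffer = cache.buf.data(); // re-point the context at this copy's buffer

    cache.ctx = ggml_init(params);

    if (!cache.ctx) {
        fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
        return false;
    }

    cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
    cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);

    return true;
}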
return false; + } } } @@ -3485,6 +3543,12 @@ int whisper_full_parallel( ctx->t_sample_us += ctxs[i].t_sample_us; ctx->t_encode_us += ctxs[i].t_encode_us; ctx->t_decode_us += ctxs[i].t_decode_us; + + kv_cache_free(ctx->kv_cross); + + for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { + kv_cache_free(ctx->decoders[j].kv_self); + } } // average the timings From 9551d7fabd690f0354c72bf871ceeeeaaf36e499 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Jan 2023 22:15:54 +0200 Subject: [PATCH 08/23] whisper : wip decoding parameters + strategies --- whisper.cpp | 172 +++++++++++++++++++++++++++++++++++++--------------- whisper.h | 15 ++++- 2 files changed, 135 insertions(+), 52 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 52c70dcddb5..6759cd2c0f0 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -416,11 +416,16 @@ struct whisper_model { std::map tensors; }; +struct whisper_sequence { + std::vector tokens; +}; + struct whisper_decoder { whisper_kv_cache kv_self; -}; -struct whisper_sequence { + whisper_sequence sequence; + + std::vector prompt; }; struct whisper_context { @@ -759,6 +764,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.work_logprobs.reserve(vocab.n_vocab); wctx.probs_id.reserve(n_vocab); + + wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); + wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx); } size_t ctx_size = 0; @@ -2766,46 +2774,54 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str case WHISPER_SAMPLING_GREEDY: { result = { - /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, - /*.language =*/ "en", + /*.language =*/ "en", + + /*.suppress_blank =*/ true, - /*.suppress_blank =*/ true, + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, - /*.max_initial_timestamp =*/ 1.0, + /*.temperature_increment =*/ 0.2f, + /*.compression_ratio_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, - /*.greedy =*/ { - /*.dummy =*/ 0, + /*.greedy =*/ { + /*.best_of =*/ 5, }, - /*.beam_search =*/ { - /*.beam_width =*/ -1, - /*.n_best =*/ -1, + /*.beam_search =*/ { + /*.beam_size =*/ -1, + + /*.patience =*/ -1.0f, + /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -2847,17 +2863,25 @@ struct 
whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.language =*/ "en", - /*.suppress_blank =*/ true, + /*.suppress_blank =*/ true, + + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, - /*.max_initial_timestamp =*/ 1.0, + /*.temperature_increment =*/ 0.2f, + /*.compression_ratio_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, - /*.greedy =*/ { - /*.dummy =*/ 0, + /*.greedy =*/ { + /*.best_of =*/ 5, }, - /*.beam_search =*/ { - /*.beam_width =*/ 10, - /*.n_best =*/ 5, + /*.beam_search =*/ { + /*.beam_size =*/ 5, + + /*.patience =*/ -1.0f, + /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -3142,6 +3166,45 @@ int whisper_full( return 0; } + // a set of temperatures to use + // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] + std::vector temperatures; + if (params.temperature_increment > 0.0f) { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) { + temperatures.push_back(t); + } + } else { + temperatures.push_back(params.temperature); + } + + // initialize the decoders + int n_decoders = 1; + switch (params.strategy) { + case WHISPER_SAMPLING_GREEDY: + { + n_decoders = params.greedy.best_of; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); + } break; + }; + + for (int i = 1; i < n_decoders; i++) { + if (ctx->decoders[i].kv_self.ctx == nullptr) { + ctx->decoders[i].kv_self = ctx->decoders[0].kv_self; + if (!kv_cache_reinit(ctx->decoders[i].kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i); + return -4; + } + + fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i); + + ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity()); + } + } + // the accumulated text context so far auto & prompt_past = ctx->prompt_past; if (params.no_context) { @@ -3160,7 +3223,7 @@ int whisper_full( // overwrite audio_ctx, max allowed is hparams.n_audio_ctx if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { fprintf(stderr, "%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); - return -4; + return -5; } ctx->exp_n_audio_ctx = params.audio_ctx; @@ -3201,12 +3264,6 @@ int whisper_full( break; } - // if there is a very short audio segment left to process, we remove any past prompt since it tends - // to confuse the decoder and often make it repeat or hallucinate stuff - if (seek > seek_start && seek + 500 >= seek_end) { - prompt_past.clear(); - } - if (params.encoder_begin_callback) { if (params.encoder_begin_callback(ctx, params.encoder_begin_callback_user_data) == false) { fprintf(stderr, "%s: encoder_begin_callback returned false - aborting\n", __func__); @@ -3217,7 +3274,13 @@ int whisper_full( // encode audio features starting at offset seek if (whisper_encode(ctx, seek, params.n_threads) != 0) { fprintf(stderr, "%s: failed to encode\n", __func__); - return -4; + return -6; + } + + // if there is a very short audio segment left to process, we remove any past prompt since it tends + // to confuse the decoder and often make it repeat or hallucinate stuff + if (seek > seek_start && seek + 500 >= seek_end) { + prompt_past.clear(); } prompt.clear(); @@ -3257,7 +3320,7 @@ int whisper_full( for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < 
n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); - return -5; + return -7; } n_past += prompt.size(); @@ -3469,16 +3532,27 @@ int whisper_full_parallel( ctx_p = *ctx; + ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); + ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); + + ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab); + ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab); + + ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab); + if (!kv_cache_reinit(ctx_p.kv_cross)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention cache\n", __func__); + fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); return false; } for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { if (ctx_p.decoders[j].kv_self.ctx && !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention cache\n", __func__); + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); return false; } + + ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); + ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx); } } diff --git a/whisper.h b/whisper.h index 2983765275a..a3529c84c4b 100644 --- a/whisper.h +++ b/whisper.h @@ -278,15 +278,24 @@ extern "C" { // common decoding parameters: bool suppress_blank; + float temperature; float max_initial_timestamp; + // fallback parameters + float temperature_increment; + float compression_ratio_threshold; + float logprob_threshold; + float no_speech_threshold; + struct { - int dummy; + int best_of; } greedy; struct { - int beam_width; - int n_best; + int beam_size; + + float patience; + float length_penalty; } beam_search; whisper_new_segment_callback new_segment_callback; From 3d723d0b82b84cf998efb21b63a713ce73cfc3a9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 12 Jan 2023 21:44:05 +0200 Subject: [PATCH 09/23] whisper : wip decoding parameters + strategies (part 2) --- whisper.cpp | 589 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 371 insertions(+), 218 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 6759cd2c0f0..4e11e676416 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -418,14 +418,31 @@ struct whisper_model { struct whisper_sequence { std::vector tokens; + + // the accumulated transcription in the current interation (used to truncate the tokens array) + int result_len; + + double sum_logprobs; }; +// TAGS: WHISPER_DECODER_INIT struct whisper_decoder { whisper_kv_cache kv_self; whisper_sequence sequence; - std::vector prompt; + int n_past; + int seek_delta; + + bool failed; // has the current segment failed to decode? + bool completed; // has the decoder completed the current segment? + bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? 
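// Note: a usage sketch (not part of this patch) for the fallback fields that the
// whisper.h hunk above introduces. It relies only on the public API; `ctx` and
// the `pcm` sample buffer are stand-ins for the caller's own data:

whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

wparams.temperature           = 0.0f;  // start fully deterministic
wparams.temperature_increment = 0.2f;  // on failure, retry at 0.2, 0.4, ... (< 1.0)
wparams.logprob_threshold     = -1.0f; // fall back when the avg logprob drops below this
wparams.greedy.best_of        = 5;     // independent decoders once t > 0

if (whisper_full(ctx, wparams, pcm.data(), (int) pcm.size()) != 0) {
    fprintf(stderr, "failed to process audio\n");
}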
+
+    std::vector<whisper_token> tokens;
+
+    // new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab])
+    std::vector<float> logits;
+    std::vector<float> logprobs;
 };
 
 struct whisper_context {
@@ -447,23 +464,17 @@ struct whisper_context {
 
     whisper_decoder decoders[WHISPER_MAX_DECODERS] = {};
 
-    int selected_decoder_id = 0;
-
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
    std::vector<uint8_t> buf_compute_layer;
 
-    // decode output
+    // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> probs;
     std::vector<float> logits;
 
     std::vector<whisper_segment> result_all;
 
     std::vector<whisper_token> prompt_past;
 
-    // used to avoid allocations
-    std::vector<float> work_logits;
-    std::vector<float> work_logprobs;
-
     std::vector<std::pair<double, whisper_token>> probs_id;
 
     // [EXPERIMENTAL] token-level timestamps data
@@ -651,8 +662,6 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     wctx.model.buf = new std::vector<uint8_t>();
     wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type));
 
-    wctx.selected_decoder_id = 0;
-
     if (!kv_cache_init(model.hparams, scale*MEM_REQ_KV_SELF.at(model.type), wctx.decoders[0].kv_self, wctx.wtype, model.hparams.n_text_ctx)) {
         fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
         return false;
@@ -760,13 +769,13 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
     wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
     wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx);
 
-    wctx.work_logits.reserve (vocab.n_vocab);
-    wctx.work_logprobs.reserve(vocab.n_vocab);
-
     wctx.probs_id.reserve(n_vocab);
 
+    // TAGS: WHISPER_DECODER_INIT
     wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx);
-    wctx.decoders[0].prompt.reserve(model.hparams.n_text_ctx);
+
+    wctx.decoders[0].logits.reserve (vocab.n_vocab);
+    wctx.decoders[0].logprobs.reserve(vocab.n_vocab);
 }
 
 size_t ctx_size = 0;
@@ -1586,18 +1595,16 @@ static bool whisper_encode(
 // - n_past: number of past tokens to prefix the prompt with
 //
 static bool whisper_decode(
-        whisper_context & wctx,
-        const int n_threads,
-        const whisper_token * tokens,
-        const int n_tokens,
-        const int n_past) {
+        whisper_context & wctx,
+        whisper_decoder & decoder,
+        const int n_threads,
+        const whisper_token * tokens,
+        const int n_tokens,
+        const int n_past) {
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;
 
-    WHISPER_ASSERT(wctx.selected_decoder_id >= 0);
-    WHISPER_ASSERT(wctx.selected_decoder_id < WHISPER_MAX_DECODERS);
-
-    auto & kv_self = wctx.decoders[wctx.selected_decoder_id].kv_self;
+    auto & kv_self = decoder.kv_self;
 
     WHISPER_ASSERT(!!kv_self.ctx);
@@ -2506,7 +2513,10 @@ int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) {
 int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
-    if (!whisper_decode(*ctx, n_threads, tokens, n_tokens, n_past)) {
+    // TODO: add selected_decoder_id to context
+    const int selected_decoder_id = 0;
+
+    if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
@@ -2516,28 +2526,6 @@ int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, i
     return 0;
 }
 
-// TODO: remove
-static struct whisper_token_data whisper_sample_best(struct whisper_context * ctx) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(*ctx, ctx->probs.data() +
(ctx->probs.size() - ctx->vocab.n_vocab), false, false);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
-// TODO: remove
-static struct whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial) {
-    const int64_t t_start_sample_us = ggml_time_us();
-
-    const auto res = whisper_sample_best(*ctx, ctx->probs.data() + (ctx->probs.size() - ctx->vocab.n_vocab), true, is_initial);
-
-    ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
-
-    return res;
-}
-
 int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) {
     const auto res = tokenize(ctx->vocab, text);
 
@@ -2899,15 +2887,15 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
 
 // forward declarations
 static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
 static void whisper_exp_compute_token_level_timestamps(
-        struct whisper_context * ctx,
-        int i_segment,
-        float thold_pt,
-        float thold_ptsum);
+        struct whisper_context & ctx,
+        int i_segment,
+        float thold_pt,
+        float thold_ptsum);
 
 // wrap the last segment to max_len characters
 // returns the number of new segments
-static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
-    auto segment = ctx->result_all.back();
+static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
+    auto segment = ctx.result_all.back();
 
     int res = 1;
     int acc = 0;
@@ -2916,34 +2904,34 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
     for (int i = 0; i < (int) segment.tokens.size(); i++) {
         const auto & token = segment.tokens[i];
-        if (token.id >= whisper_token_eot(ctx)) {
+        if (token.id >= whisper_token_eot(&ctx)) {
             continue;
         }
 
-        const auto txt = whisper_token_to_str(ctx, token.id);
+        const auto txt = whisper_token_to_str(&ctx, token.id);
 
         const int cur = strlen(txt);
 
         if (acc + cur > max_len && i > 0) {
             // split here
-            ctx->result_all.back().text = std::move(text);
-            ctx->result_all.back().t1 = token.t0;
-            ctx->result_all.back().tokens.resize(i);
+            ctx.result_all.back().text = std::move(text);
+            ctx.result_all.back().t1 = token.t0;
+            ctx.result_all.back().tokens.resize(i);
 
-            ctx->result_all.push_back({});
-            ctx->result_all.back().t0 = token.t0;
-            ctx->result_all.back().t1 = segment.t1;
+            ctx.result_all.push_back({});
+            ctx.result_all.back().t0 = token.t0;
+            ctx.result_all.back().t1 = segment.t1;
 
             // add tokens [i, end] to the new segment
-            ctx->result_all.back().tokens.insert(
-                ctx->result_all.back().tokens.end(),
+            ctx.result_all.back().tokens.insert(
+                ctx.result_all.back().tokens.end(),
                     segment.tokens.begin() + i,
                     segment.tokens.end());
 
             acc = 0;
             text = "";
 
-            segment = ctx->result_all.back();
+            segment = ctx.result_all.back();
             i = -1;
 
             res++;
@@ -2953,33 +2941,33 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
         }
     }
 
-    ctx->result_all.back().text = std::move(text);
+    ctx.result_all.back().text = std::move(text);
 
     return res;
 }
 
-static struct whisper_token_data whisper_sample_next_token(
-        struct whisper_context * ctx,
-        struct whisper_full_params params,
-        double & sum_logprobs,
-        const std::vector<whisper_token> & prompt,
-        const std::vector<whisper_token_data> & tokens_cur) {
-    struct whisper_token_data result = {};
-
-    const auto & vocab = ctx->vocab;
+// process the logits for the selected decoder
+// - applies logit filters
+// - computes logprobs
+static void whisper_process_logits(
+        struct whisper_context & ctx,
+        struct whisper_decoder & decoder,
+ struct whisper_full_params params) { + const auto & vocab = ctx.vocab; + const auto & tokens_cur = decoder.sequence.tokens; const bool is_initial = tokens_cur.size() == 0; const int n_logits = vocab.id_to_token.size(); - WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab); + WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); // extract the logits for the last token - // we will be mutating and therefore we don't want to use the ctx->logits buffer directly - auto & logits = ctx->work_logits; - auto & logprobs = ctx->work_logprobs; + // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + auto & logits = decoder.logits; + auto & logprobs = decoder.logprobs; { logits.resize(n_logits); - memcpy(logits.data(), ctx->logits.data() + (ctx->logits.size() - n_logits), n_logits*sizeof(float)); + memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); // will be populated a bit later logprobs.resize(n_logits); @@ -3023,7 +3011,7 @@ static struct whisper_token_data whisper_sample_next_token( // the initial timestamp cannot be larger than max_initial_timestamp // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 if (is_initial && params.max_initial_timestamp > 0.0f) { - const float precision = float(WHISPER_CHUNK_SIZE)/ctx->model.hparams.n_audio_ctx; + const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; const int tid0 = std::round(params.max_initial_timestamp/precision); for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) { @@ -3091,24 +3079,24 @@ static struct whisper_token_data whisper_sample_next_token( printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); #endif - switch (params.strategy) { - case WHISPER_SAMPLING_GREEDY: - { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274 - // TODO: implement - result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); - } break; - case WHISPER_SAMPLING_BEAM_SEARCH: - { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364 - // TODO: implement - } break; - } + //switch (params.strategy) { + // case WHISPER_SAMPLING_GREEDY: + // { + // // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274 + // // TODO: implement + // result = (is_initial) ? 
whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); + // } break; + // case WHISPER_SAMPLING_BEAM_SEARCH: + // { + // // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364 + // // TODO: implement + // } break; + //} - sum_logprobs += logprobs[result.id]; - printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1)); + //sum_logprobs += logprobs[result.id]; + //printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1)); - return result; + //return result; } int whisper_full( @@ -3191,6 +3179,7 @@ int whisper_full( }; for (int i = 1; i < n_decoders; i++) { + // TAGS: WHISPER_DECODER_INIT if (ctx->decoders[i].kv_self.ctx == nullptr) { ctx->decoders[i].kv_self = ctx->decoders[0].kv_self; if (!kv_cache_reinit(ctx->decoders[i].kv_self)) { @@ -3201,7 +3190,9 @@ int whisper_full( fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i); ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); - ctx->decoders[i].prompt.reserve(ctx->decoders[0].prompt.capacity()); + + ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab); + ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab); } } @@ -3242,14 +3233,12 @@ int whisper_full( int progress_prev = 0; int progress_step = 5; - std::vector tokens_cur; - tokens_cur.reserve(whisper_n_text_ctx(ctx)); + int seek = seek_start; std::vector prompt; prompt.reserve(whisper_n_text_ctx(ctx)); // main loop - int seek = seek_start; while (true) { const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); while (progress_cur >= progress_prev + progress_step) { @@ -3283,133 +3272,297 @@ int whisper_full( prompt_past.clear(); } - prompt.clear(); + //prompt.clear(); - // if we have already generated some text, use it as a prompt to condition the next generation - if (!prompt_past.empty()) { - int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + //// if we have already generated some text, use it as a prompt to condition the next generation + //if (!prompt_past.empty()) { + // int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); - prompt = { whisper_token_prev(ctx) }; - prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); + // prompt = { whisper_token_prev(ctx) }; + // prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); - prompt_past.clear(); - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); - } + // prompt_past.clear(); + // prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); + //} + + //prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + + //// print the prompt + ////printf("\n\n"); + ////for (int i = 0; i < prompt.size(); i++) { + //// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); + ////} + ////printf("\n\n"); + + //int n_past = 0; + //int seek_delta = 100*WHISPER_CHUNK_SIZE; + + //// the accumulated transcription in the current interation + //int result_len = 0; + //tokens_cur.clear(); + + //bool failed = false; // has the current segment failed to decode? 
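// Note: the "computes logprobs" step of whisper_process_logits() above is a
// standard log-softmax over the already-filtered logits. A self-contained sketch
// of that identity with illustrative names (not the exact code this series lands on):

#include <algorithm>
#include <cmath>
#include <vector>

static void log_softmax_sketch(const std::vector<float> & logits, std::vector<float> & logprobs) {
    // max-shift for numerical stability; exp(-INFINITY - max) correctly yields 0
    const float logit_max = *std::max_element(logits.begin(), logits.end());

    double sum_exp = 0.0;
    for (const float l : logits) {
        sum_exp += exp(l - logit_max);
    }

    const float log_sum = logit_max + (float) log(sum_exp);

    logprobs.resize(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        logprobs[i] = logits[i] - log_sum; // log(softmax(logits)[i])
    }
}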
+ //bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? + + //double sum_logprobs = 0.0; + + //for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + // if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { + // fprintf(stderr, "%s: failed to decode\n", __func__); + // return -7; + // } + + // n_past += prompt.size(); + // prompt.clear(); + + // // sample the next token based on the selected decoding strategy + parameters + // // also, update the sliding window position based on the sampled timestamp tokens + // { + // const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur); + + // // timestamp token - update sliding window + // if (token.id > whisper_token_beg(ctx)) { + // const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // // do not allow to go back in time + // if (has_ts && seek_delta > seek_delta_new && result_len < i) { + // break; + // } + + // seek_delta = seek_delta_new; + // result_len = i + 1; + // has_ts = true; + // } + + // // add it to the context + // prompt.push_back(token.id); + // tokens_cur.push_back(token); + + // //{ + // // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; + // // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + // //} + + // // end of segment + // if (token.id == whisper_token_eot(ctx) || // end of text token + // (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + // (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + // ) { + // if (result_len == 0) { + // if (seek + seek_delta + 100 >= seek_end) { + // result_len = i + 1; + // } else { + // failed = true; + // break; + // } + // } + + // if (params.single_segment) { + // result_len = i + 1; + // seek_delta = 100*WHISPER_CHUNK_SIZE; + // } + + // break; + // } + + // // TESTS: if no tensors are loaded, it means we are running tests + // if (ctx->model.n_loaded == 0) { + // seek_delta = 100*WHISPER_CHUNK_SIZE; + // break; + // } + // } + + // // sometimes, the decoding can get stuck in a repetition loop + // // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + // if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + // failed = true; + // break; + // } + //} + + //if (failed) { + // // when we fail to sample timestamp token, retry by clearing the past prompt + // // if it fails again, then we advance the window by 1 second + // if (!prompt_past.empty()) { + // prompt_past.clear(); + // } else { + // fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__); + // seek += 100; + // } + // continue; + //} - prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + //// shrink down to result_len + //tokens_cur.resize(result_len); - // print the prompt - //printf("\n\n"); - //for (int i = 0; i < prompt.size(); i++) { - // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); + //for (const auto & r : tokens_cur) { + // prompt_past.push_back(r.id); //} - //printf("\n\n"); - int n_past = 0; - int seek_delta = 100*WHISPER_CHUNK_SIZE; + for (int it = 0; it < (int) temperatures.size(); ++it) { + const float t_cur = temperatures[it]; - // the accumulated transcription in the current interation - int result_len = 0; - 
tokens_cur.clear(); + int n_decoders_cur = 1; + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } else { + n_decoders_cur = params.beam_search.beam_size; + } + } break; + }; - bool failed = false; // has the current segment failed to decode? - bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? + fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur); - double sum_logprobs = 0.0; + if (t_cur > 0.5) { + prompt_past.clear(); - for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { - if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { - fprintf(stderr, "%s: failed to decode\n", __func__); - return -7; + fprintf(stderr, "%s: clearing prompt_past\n", __func__); } - n_past += prompt.size(); - prompt.clear(); + // TAGS: WHISPER_DECODER_INIT + for (int i = 0; i < n_decoders_cur; ++i) { + auto & decoder = ctx->decoders[i]; - // sample the next token based on the selected decoding strategy + parameters - // also, update the sliding window position based on the sampled timestamp tokens - { - const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur); + decoder.sequence.tokens.clear(); + decoder.sequence.result_len = 0; + decoder.sequence.sum_logprobs = 0.0; - // timestamp token - update sliding window - if (token.id > whisper_token_beg(ctx)) { - const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + decoder.n_past = 0; + decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; - // do not allow to go back in time - if (has_ts && seek_delta > seek_delta_new && result_len < i) { - break; - } + decoder.failed = false; + decoder.completed = false; + decoder.has_ts = false; + } - seek_delta = seek_delta_new; - result_len = i + 1; - has_ts = true; + // init prompt and kv cache for the current iteration + // run whisper_decoder() only for decoder 0 and copy the results for the other decoders + { + prompt.clear(); + + // if we have already generated some text, use it as a prompt to condition the next generation + if (!prompt_past.empty()) { + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + + prompt = { whisper_token_prev(ctx) }; + prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); } - // add it to the context - prompt.push_back(token.id); - tokens_cur.push_back(token); + // init new transcription with sot, language (opt) and task tokens + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); - //{ - // const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token[token.tid] : "[?]"; - // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + // print the prompt + //printf("\n\n"); + //for (int i = 0; i < prompt.size(); i++) { + // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); //} + //printf("\n\n"); - // end of segment - if (token.id == whisper_token_eot(ctx) || // end of text token - (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached - (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached - ) { - if (result_len == 0) { - if (seek + seek_delta + 100 >= seek_end) { - result_len = i + 1; - } else { - failed = true; - break; - } - } + if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -7; + } - if (params.single_segment) { - result_len = i + 1; - seek_delta = 100*WHISPER_CHUNK_SIZE; - } + whisper_process_logits(*ctx, ctx->decoders[0], params); - break; + for (int i = 1; i < n_decoders_cur; ++i) { + auto & decoder = ctx->decoders[i]; + + memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size()); + + decoder.n_past += prompt.size(); + + memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } + } + + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + bool completed = true; + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + for (int i = 0; i < n_decoders_cur; ++i) { + auto & decoder = ctx->decoders[i]; + + if (decoder.completed || decoder.failed) { + continue; + } + + if (t_cur < 1e-6f) { + // select top token + } else { + } + } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + // TODO: .. 
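// Note: the empty `else` branch of the greedy case above (reached when
// t_cur >= 1e-6f) is expected to sample from softmax(logits/t_cur) instead of
// taking the arg max; the 1e-6f guard keeps the division well defined at
// temperature zero. A sketch of that sampling step with illustrative names
// (not the code this series lands on):

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

static int sample_with_temperature_sketch(const std::vector<float> & logits, float t_cur, std::mt19937 & rng) {
    const float logit_max = *std::max_element(logits.begin(), logits.end());

    std::vector<double> weights(logits.size());
    for (size_t i = 0; i < logits.size(); ++i) {
        weights[i] = exp((logits[i] - logit_max)/t_cur); // unnormalized softmax at temperature t_cur
    }

    // discrete_distribution normalizes the weights internally
    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    return dist(rng);
}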
+ } break; + }; - // TESTS: if no tensors are loaded, it means we are running tests - if (ctx->model.n_loaded == 0) { - seek_delta = 100*WHISPER_CHUNK_SIZE; + if (completed) { break; } + + for (int i = 0; i < n_decoders_cur; ++i) { + auto & decoder = ctx->decoders[i]; + + if (decoder.failed || decoder.completed) { + continue; + } + + decoder.tokens.resize(1); + decoder.tokens[0] = decoder.sequence.tokens.back().id; + + if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) { + fprintf(stderr, "%s: failed to decode\n", __func__); + return -7; + } + + whisper_process_logits(*ctx, decoder, params); + + ++decoder.n_past; + } } - // sometimes, the decoding can get stuck in a repetition loop - // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy - if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { - failed = true; - break; + // TODO: rank the resulting sequences and select the best one + { } - } - if (failed) { - // when we fail to sample timestamp token, retry by clearing the past prompt - // if it fails again, then we advance the window by 1 second - if (!prompt_past.empty()) { - prompt_past.clear(); - } else { - fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__); - seek += 100; + bool success = true; + + // TODO: implement logprob threshold + compression threshold + { } - continue; - } - // shrink down to result_len - tokens_cur.resize(result_len); + if (success) { + break; + } - for (const auto & r : tokens_cur) { - prompt_past.push_back(r.id); + fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } + // TODO + const int seek_delta = 0; + whisper_sequence seq_best; + + const auto & tokens_cur = seq_best.tokens; + + // TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens + //prompt_past.clear(); + //prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); + // store the text from this iteration if (!tokens_cur.empty()) { int i0 = 0; @@ -3450,10 +3603,10 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(ctx, params.max_len); + n_new = whisper_wrap_segment(*ctx, params.max_len); } } if (params.new_segment_callback) { @@ -3494,10 +3647,10 @@ int whisper_full( if (params.token_timestamps) { whisper_exp_compute_token_level_timestamps( - ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); if (params.max_len > 0) { - n_new = whisper_wrap_segment(ctx, params.max_len); + n_new = whisper_wrap_segment(*ctx, params.max_len); } } if (params.new_segment_callback) { @@ -3535,9 +3688,6 @@ int whisper_full_parallel( ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); ctx_p.probs.reserve (ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); - ctx_p.work_logits.reserve (ctx_p.vocab.n_vocab); - ctx_p.work_logprobs.reserve(ctx_p.vocab.n_vocab); - ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab); if (!kv_cache_reinit(ctx_p.kv_cross)) { @@ -3545,6 +3695,7 @@ int whisper_full_parallel( return false; } + // TAGS: WHISPER_DECODER_INIT for (int j = 0; j < WHISPER_MAX_DECODERS; ++j) { if (ctx_p.decoders[j].kv_self.ctx 
&& !kv_cache_reinit(ctx_p.decoders[j].kv_self)) { fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d, processor %d\n", __func__, j, i); @@ -3552,7 +3703,9 @@ int whisper_full_parallel( } ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); - ctx_p.decoders[j].prompt.reserve(ctx_p.model.hparams.n_text_ctx); + + ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); + ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); } } @@ -3747,14 +3900,14 @@ static std::vector get_signal_energy(const float * signal, int n_samples, } static void whisper_exp_compute_token_level_timestamps( - struct whisper_context * ctx, - int i_segment, - float thold_pt, - float thold_ptsum) { - auto & segment = ctx->result_all[i_segment]; + struct whisper_context & ctx, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = ctx.result_all[i_segment]; auto & tokens = segment.tokens; - const int n_samples = ctx->energy.size(); + const int n_samples = ctx.energy.size(); if (n_samples == 0) { fprintf(stderr, "%s: no signal data available\n", __func__); @@ -3777,28 +3930,28 @@ static void whisper_exp_compute_token_level_timestamps( return; } - auto & t_beg = ctx->t_beg; - auto & t_last = ctx->t_last; - auto & tid_last = ctx->tid_last; + auto & t_beg = ctx.t_beg; + auto & t_last = ctx.t_last; + auto & tid_last = ctx.tid_last; for (int j = 0; j < n; ++j) { auto & token = tokens[j]; if (j == 0) { - if (token.id == whisper_token_beg(ctx)) { + if (token.id == whisper_token_beg(&ctx)) { tokens[j ].t0 = t0; tokens[j ].t1 = t0; tokens[j + 1].t0 = t0; t_beg = t0; t_last = t0; - tid_last = whisper_token_beg(ctx); + tid_last = whisper_token_beg(&ctx); } else { tokens[j ].t0 = t_last; } } - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx)); tokens[j].id = token.id; tokens[j].tid = token.tid; @@ -3806,7 +3959,7 @@ static void whisper_exp_compute_token_level_timestamps( tokens[j].pt = token.pt; tokens[j].ptsum = token.ptsum; - tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id)); + tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { if (j > 0) { @@ -3885,7 +4038,7 @@ static void whisper_exp_compute_token_level_timestamps( const int hw = WHISPER_SAMPLE_RATE/8; for (int j = 0; j < n; j++) { - if (tokens[j].id >= whisper_token_eot(ctx)) { + if (tokens[j].id >= whisper_token_eot(&ctx)) { continue; } @@ -3900,15 +4053,15 @@ static void whisper_exp_compute_token_level_timestamps( float sum = 0.0f; for (int k = ss0; k < ss1; k++) { - sum += ctx->energy[k]; + sum += ctx.energy[k]; } const float thold = 0.5*sum/ns; { int k = s0; - if (ctx->energy[k] > thold && j > 0) { - while (k > 0 && ctx->energy[k] > thold) { + if (ctx.energy[k] > thold && j > 0) { + while (k > 0 && ctx.energy[k] > thold) { k--; } tokens[j].t0 = sample_to_timestamp(k); @@ -3918,7 +4071,7 @@ static void whisper_exp_compute_token_level_timestamps( s0 = k; } } else { - while (ctx->energy[k] < thold && k < s1) { + while (ctx.energy[k] < thold && k < s1) { k++; } s0 = k; @@ -3928,8 +4081,8 @@ static void whisper_exp_compute_token_level_timestamps( { int k = s1; - if (ctx->energy[k] > thold) { - while (k < n_samples - 1 && ctx->energy[k] > thold) { + if (ctx.energy[k] > thold) { + while (k < n_samples - 1 && ctx.energy[k] > thold) { k++; } tokens[j].t1 = sample_to_timestamp(k); @@ -3939,7 +4092,7 @@ 
static void whisper_exp_compute_token_level_timestamps( s1 = k; } } else { - while (ctx->energy[k] < thold && k > s0) { + while (ctx.energy[k] < thold && k > s0) { k--; } s1 = k; From 116dd67a158d328211bf2bd96b1e903bff8a4b8e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 11:09:23 +0200 Subject: [PATCH 10/23] whisper : wip decoding parameters + strategies (part 3) --- whisper.cpp | 664 +++++++++++++++++++++++++--------------------------- whisper.h | 3 +- 2 files changed, 327 insertions(+), 340 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 4e11e676416..ed6c52dd455 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -422,7 +422,9 @@ struct whisper_sequence { // the accumulated transcription in the current interation (used to truncate the tokens array) int result_len; - double sum_logprobs; + double sum_logprobs; // the sum of the log probabilities of the tokens + double avg_logprobs; // the average log probability of the tokens + double score; // likelihood rank score }; // TAGS: WHISPER_DECODER_INIT @@ -438,11 +440,12 @@ struct whisper_decoder { bool completed; // has the decoder completed the current segment? bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? - std::vector tokens; - - // new token logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + std::vector probs; std::vector logits; std::vector logprobs; + + std::vector tokens_tmp; // used for whisper_decode calls }; struct whisper_context { @@ -774,6 +777,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // TAGS: WHISPER_DECODER_INIT wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); + wctx.decoders[0].probs.reserve (vocab.n_vocab); wctx.decoders[0].logits.reserve (vocab.n_vocab); wctx.decoders[0].logprobs.reserve(vocab.n_vocab); } @@ -1959,99 +1963,6 @@ static bool whisper_decode( return true; } -// the most basic sampling scheme - select the top token -static whisper_token_data whisper_sample_best( - whisper_context & ctx, - const float * probs, - bool force_timestamp, - bool is_initial) { - whisper_token_data result = { - 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, - }; - - const auto & vocab = ctx.vocab; - - const int n_logits = vocab.n_vocab; - - auto & probs_id = ctx.probs_id; - - probs_id.clear(); - for (int i = 0; i < n_logits; i++) { - probs_id.emplace_back(probs[i], i); - } - - { - double sum_ts = 0.0; - double max_ts = -1.0; - double max_tx = -1.0; - - for (int i = 0; i < vocab.token_beg; i++) { - max_tx = std::max(max_tx, probs_id[i].first); - } - - const auto i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; - const auto i1 = is_initial ? 
vocab.token_beg + 101 : n_logits; - - // the initial timestamp cannot be larger than 100 - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 - if (is_initial) { - for (int i = i0; i < n_logits; ++ i) { - probs_id[i].first = -INFINITY; - } - } - - for (int i = vocab.token_beg; i < i1; i++) { - sum_ts += probs_id[i].first; - if (probs_id[i].first > max_ts) { - max_ts = probs_id[i].first; - result.tid = probs_id[i].second; - } - } - - // if the probability sum of all timestamp tokens is higher than the max probability of the text tokens - sample a - // timestamp token - if (sum_ts > max_tx || force_timestamp) { - // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 - for (int i = 0; i < vocab.token_beg; i++) { - probs_id[i].first = -INFINITY; - } - } - - result.pt = max_ts/(sum_ts + 1e-10); - result.ptsum = sum_ts; - } - - // find the top K tokens - const int top_k = 4; - - std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), - [](const std::pair & a, const std::pair & b) { - return a.first > b.first; - }); - - probs_id.resize(top_k); - - //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); - //} - - int res = 0; - while ((probs_id[res].second == vocab.token_sot || - probs_id[res].second == vocab.token_solm || - probs_id[res].second == vocab.token_not) && - res < (int) probs_id.size() - 1) { - res++; - } - - result.id = probs_id[res].second; - result.p = probs_id[res].first; - - return result; -} - // 500 -> 00:05.000 // 6000 -> 01:00.000 static std::string to_timestamp(int64_t t, bool comma = false) { @@ -2793,8 +2704,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, - /*.temperature =*/ 0.0f, - /*.max_initial_timestamp =*/ 1.0f, + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, + /*.length_penalty =*/ -1.0f, /*.temperature_increment =*/ 0.2f, /*.compression_ratio_threshold =*/ 2.4f, @@ -2809,7 +2721,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_size =*/ -1, /*.patience =*/ -1.0f, - /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -2853,8 +2764,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.suppress_blank =*/ true, - /*.temperature =*/ 0.0f, - /*.max_initial_timestamp =*/ 1.0f, + /*.temperature =*/ 0.0f, + /*.max_initial_timestamp =*/ 1.0f, + /*.length_penalty =*/ -1.0f, /*.temperature_increment =*/ 0.2f, /*.compression_ratio_threshold =*/ 2.4f, @@ -2869,7 +2781,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.beam_size =*/ 5, /*.patience =*/ -1.0f, - /*.length_penalty =*/ -1.0f, }, /*.new_segment_callback =*/ nullptr, @@ -2963,9 +2874,13 @@ static void whisper_process_logits( // extract the logits for the last token // we will be mutating and therefore we don't want to use the ctx.logits buffer directly + auto & probs = decoder.probs; auto & logits = decoder.logits; auto & logprobs = decoder.logprobs; { + probs.resize(n_logits); + memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float)); + logits.resize(n_logits); memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); @@ -2995,6 
+2910,8 @@ static void whisper_process_logits( const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + //fprintf(stderr, "last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + if (last_was_timestamp) { if (penultimate_was_timestamp) { for (int i = vocab.token_beg; i < n_logits; ++ i) { @@ -3099,6 +3016,83 @@ static void whisper_process_logits( //return result; } +// select the most probable token +static whisper_token_data whisper_sample_best( + whisper_context & ctx, + whisper_decoder & decoder) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; + + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + for (int i = 0; i < n_logits; ++i) { + // never sample these: + if (i == vocab.token_sot || + i == vocab.token_solm || + i == vocab.token_not) { + continue; + } + + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + + return result; +} + +// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 +static void whisper_sequence_score( + const struct whisper_full_params & params, + whisper_sequence & sequence) { + WHISPER_ASSERT(sequence.result_len > 0); + + double result = 0.0f; + + for (int i = 0; i < sequence.result_len; ++i) { + result += sequence.tokens[i].plog; + } + + sequence.sum_logprobs = result; + sequence.avg_logprobs = result/sequence.result_len; + + double penalty = sequence.result_len; + + if (params.length_penalty > 0.0f) { + penalty = pow((5.0 + penalty) / 6.0, params.length_penalty); + } + + sequence.score = result/penalty; +} + int whisper_full( struct whisper_context * ctx, struct whisper_full_params params, @@ -3191,6 +3185,7 @@ int whisper_full( ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + ctx->decoders[i].probs.reserve (ctx->vocab.n_vocab); ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab); ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab); } @@ -3261,7 +3256,7 @@ int whisper_full( } // encode audio features starting at offset seek - if (whisper_encode(ctx, seek, params.n_threads) != 0) { + if (!whisper_encode(*ctx, seek, params.n_threads)) { fprintf(stderr, "%s: failed to encode\n", __func__); return -6; } @@ -3272,132 +3267,7 @@ int whisper_full( prompt_past.clear(); } - //prompt.clear(); - - //// if we have already generated some text, use it as a prompt to condition the next generation - //if (!prompt_past.empty()) { - // int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); - - // prompt = { whisper_token_prev(ctx) }; - // prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); - - // prompt_past.clear(); - // prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); - //} - - //prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); 
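// Note: a worked example for whisper_sequence_score() above, assuming
// length_penalty = 1.0 and a sequence of result_len = 10 tokens whose logprobs
// sum to -5.0 (numbers chosen purely for illustration):
//
//     penalty = pow((5.0 + 10)/6.0, 1.0) = 2.5   // the Google NMT length penalty
//     score   = -5.0/2.5                 = -2.0
//
// With the default length_penalty = -1.0f the penalty is just the token count,
// so the score reduces to the average logprob: -5.0/10 = -0.5. It is this
// average (sequence.avg_logprobs) that is later compared against
// params.logprob_threshold when deciding whether to retry at a higher temperature.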
- - //// print the prompt - ////printf("\n\n"); - ////for (int i = 0; i < prompt.size(); i++) { - //// printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); - ////} - ////printf("\n\n"); - - //int n_past = 0; - //int seek_delta = 100*WHISPER_CHUNK_SIZE; - - //// the accumulated transcription in the current interation - //int result_len = 0; - //tokens_cur.clear(); - - //bool failed = false; // has the current segment failed to decode? - //bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? - - //double sum_logprobs = 0.0; - - //for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { - // if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { - // fprintf(stderr, "%s: failed to decode\n", __func__); - // return -7; - // } - - // n_past += prompt.size(); - // prompt.clear(); - - // // sample the next token based on the selected decoding strategy + parameters - // // also, update the sliding window position based on the sampled timestamp tokens - // { - // const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur); - - // // timestamp token - update sliding window - // if (token.id > whisper_token_beg(ctx)) { - // const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); - - // // do not allow to go back in time - // if (has_ts && seek_delta > seek_delta_new && result_len < i) { - // break; - // } - - // seek_delta = seek_delta_new; - // result_len = i + 1; - // has_ts = true; - // } - - // // add it to the context - // prompt.push_back(token.id); - // tokens_cur.push_back(token); - - // //{ - // // const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; - // // printf("%s: %3d %10s %6d %6.3f '%s'\n", __func__, i, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); - // //} - - // // end of segment - // if (token.id == whisper_token_eot(ctx) || // end of text token - // (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached - // (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached - // ) { - // if (result_len == 0) { - // if (seek + seek_delta + 100 >= seek_end) { - // result_len = i + 1; - // } else { - // failed = true; - // break; - // } - // } - - // if (params.single_segment) { - // result_len = i + 1; - // seek_delta = 100*WHISPER_CHUNK_SIZE; - // } - - // break; - // } - - // // TESTS: if no tensors are loaded, it means we are running tests - // if (ctx->model.n_loaded == 0) { - // seek_delta = 100*WHISPER_CHUNK_SIZE; - // break; - // } - // } - - // // sometimes, the decoding can get stuck in a repetition loop - // // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy - // if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { - // failed = true; - // break; - // } - //} - - //if (failed) { - // // when we fail to sample timestamp token, retry by clearing the past prompt - // // if it fails again, then we advance the window by 1 second - // if (!prompt_past.empty()) { - // prompt_past.clear(); - // } else { - // fprintf(stderr, "\n%s: failed to generate timestamp token - skipping one second\n\n", __func__); - // seek += 100; - // } - // continue; - //} - - //// shrink down to result_len - //tokens_cur.resize(result_len); - - //for (const auto & r : tokens_cur) { - // prompt_past.push_back(r.id); - //} + int best_decoder_id = 0; for (int it = 
0; it < (int) temperatures.size(); ++it) { const float t_cur = temperatures[it]; @@ -3429,12 +3299,14 @@ int whisper_full( } // TAGS: WHISPER_DECODER_INIT - for (int i = 0; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; decoder.sequence.tokens.clear(); decoder.sequence.result_len = 0; decoder.sequence.sum_logprobs = 0.0; + decoder.sequence.avg_logprobs = 0.0; + decoder.sequence.score = 0.0; decoder.n_past = 0; decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; @@ -3467,65 +3339,148 @@ int whisper_full( //} //printf("\n\n"); - if (whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0) != 0) { + if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } whisper_process_logits(*ctx, ctx->decoders[0], params); - for (int i = 1; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size()); decoder.n_past += prompt.size(); + memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); } } for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { - bool completed = true; - - switch (params.strategy) { - case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: - { - for (int i = 0; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; - if (decoder.completed || decoder.failed) { - continue; - } + if (decoder.completed || decoder.failed) { + continue; + } + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { if (t_cur < 1e-6f) { - // select top token + decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder)); } else { } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + // TODO: .. + } break; + }; + + auto & has_ts = decoder.has_ts; + auto & failed = decoder.failed; + auto & completed = decoder.completed; + auto & seek_delta = decoder.seek_delta; + auto & result_len = decoder.sequence.result_len; + + { + const auto & token = decoder.sequence.tokens.back(); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + failed = true; // TODO: maybe this is not a failure ? + break; } - } break; - case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + { - // TODO: .. - } break; - }; + const auto tt = token.pt > 0.10 ? 
ctx->vocab.id_to_token[token.tid] : "[?]"; + printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + } - if (completed) { - break; + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + failed = true; + break; + } + } + + if (params.single_segment) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + + completed = true; + break; + } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + completed = true; + break; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + failed = true; + break; + } + } + + // check if all decoders have finished (i.e. completed or failed) + { + bool completed_all = true; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + completed_all = false; + } + + if (completed_all) { + break; + } } - for (int i = 0; i < n_decoders_cur; ++i) { - auto & decoder = ctx->decoders[i]; + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; if (decoder.failed || decoder.completed) { continue; } - decoder.tokens.resize(1); - decoder.tokens[0] = decoder.sequence.tokens.back().id; + decoder.tokens_tmp.resize(1); + decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; + + //fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); - if (whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens.data(), decoder.tokens.size(), decoder.n_past) != 0) { + if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } @@ -3536,14 +3491,36 @@ int whisper_full( } } - // TODO: rank the resulting sequences and select the best one + // rank the resulting sequences and select the best one { + double best_score = -1e9; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.failed) { + continue; + } + + whisper_sequence_score(params, ctx->decoders[j].sequence); + + if (best_score < decoder.sequence.score) { + best_score = decoder.sequence.score; + best_decoder_id = j; + } + } } bool success = true; - // TODO: implement logprob threshold + compression threshold + // implement logprob threshold + // TODO: implement compression threshold { + auto & decoder = ctx->decoders[best_decoder_id]; + + if (decoder.sequence.avg_logprobs < params.logprob_threshold) { + success = false; + } } if (success) { @@ -3553,113 +3530,121 @@ int whisper_full( fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } - // TODO - const int seek_delta = 0; - whisper_sequence seq_best; + { + const auto & best_decoder = ctx->decoders[best_decoder_id]; - const auto & tokens_cur = seq_best.tokens; + const 
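// NOTE (editor): from here on only the winning decoder matters: its seek_delta
// advances the audio window and its trimmed token sequence becomes the
// transcribed segment below.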
auto seek_delta = best_decoder.seek_delta; + const auto result_len = best_decoder.sequence.result_len; - // TODO: update prompt_past to keep only the last whisper_n_text_ctx(ctx)/2 tokens - //prompt_past.clear(); - //prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); + const auto & tokens_cur = best_decoder.sequence.tokens; - // store the text from this iteration - if (!tokens_cur.empty()) { - int i0 = 0; - auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + // update prompt_past + prompt_past.clear(); + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); - std::string text; + for (int i = 0; i < result_len; ++i) { + prompt_past.push_back(tokens_cur[i].id); + } - for (int i = 0; i < (int) tokens_cur.size(); i++) { - //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, - // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, - // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + // store the text from this iteration + if (!tokens_cur.empty()) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); - if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { - } else { - text += whisper_token_to_str(ctx, tokens_cur[i].id); - } - if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { - const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); - if (!text.empty()) { - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; - - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); + std::string text; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special == false && tokens_cur[i].id >= whisper_token_eot(ctx)) { + } else { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
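// NOTE (editor): with params.speed_up the audio was compressed 2x by the phase
// vocoder before encoding, so the model's timestamps span half the real time
// and are doubled back here.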
2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } } - } - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j <= i; j++) { - result_all.back().tokens.push_back(tokens_cur[j]); - } + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } - int n_new = 1; + int n_new = 1; - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len); + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; } + i--; + t0 = t1; + i0 = i + 1; } - text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { - i++; - } - i--; - t0 = t1; - i0 = i + 1; } - } - if (!text.empty()) { - const auto t1 = seek + seek_delta; + if (!text.empty()) { + const auto t1 = seek + seek_delta; - const auto tt0 = params.speed_up ? 2*t0 : t0; - const auto tt1 = params.speed_up ? 2*t1 : t1; + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
2*t1 : t1; - if (params.print_realtime) { - if (params.print_timestamps) { - printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); - } else { - printf("%s", text.c_str()); - fflush(stdout); + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } } - } - result_all.push_back({ tt0, tt1, text, {} }); - for (int j = i0; j < (int) tokens_cur.size(); j++) { - result_all.back().tokens.push_back(tokens_cur[j]); - } + result_all.push_back({ tt0, tt1, text, {} }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } - int n_new = 1; + int n_new = 1; - if (params.token_timestamps) { - whisper_exp_compute_token_level_timestamps( - *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); - if (params.max_len > 0) { - n_new = whisper_wrap_segment(*ctx, params.max_len); + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, params.max_len); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } - } - if (params.new_segment_callback) { - params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } - } - seek += seek_delta; + // update audio window + seek += seek_delta; + } } return 0; @@ -3704,6 +3689,7 @@ int whisper_full_parallel( ctx_p.decoders[j].sequence.tokens.reserve(ctx_p.model.hparams.n_text_ctx); + ctx_p.decoders[j].probs.reserve (ctx_p.vocab.n_vocab); ctx_p.decoders[j].logits.reserve (ctx_p.vocab.n_vocab); ctx_p.decoders[j].logprobs.reserve(ctx_p.vocab.n_vocab); } diff --git a/whisper.h b/whisper.h index a3529c84c4b..906c7c0ee68 100644 --- a/whisper.h +++ b/whisper.h @@ -74,6 +74,7 @@ extern "C" { whisper_token tid; // forced timestamp token id float p; // probability of the token + float plog; // log probability of the token float pt; // probability of the timestamp token float ptsum; // sum of probabilities of all timestamp tokens @@ -280,6 +281,7 @@ extern "C" { float temperature; float max_initial_timestamp; + float length_penalty; // fallback parameters float temperature_increment; @@ -295,7 +297,6 @@ extern "C" { int beam_size; float patience; - float length_penalty; } beam_search; whisper_new_segment_callback new_segment_callback; From bd6e70b512348ee5ead41cac3aa4c73e3db43150 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 12:38:13 +0200 Subject: [PATCH 11/23] whisper : wip decoding parameters + strategies (part 4) --- whisper.cpp | 286 +++++++++++++++++++++++++++------------------------- whisper.h | 6 -- 2 files changed, 151 insertions(+), 141 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index ed6c52dd455..6a1da95c013 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -472,13 +472,12 @@ struct whisper_context { std::vector buf_compute_layer; // decode output (2-dimensional array: [n_tokens][n_vocab]) - std::vector probs; std::vector logits; std::vector result_all; std::vector prompt_past; - std::vector> probs_id; + std::vector> logits_id; // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; @@ -575,6 +574,10 @@ static void kv_cache_free(struct whisper_kv_cache & cache) { static bool whisper_model_load(struct whisper_model_loader 
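// NOTE (editor): the load timer moves inside the loader in this patch, so
// t_load_us measures exactly the model load; whisper_init no longer keeps a
// timer of its own.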
* loader, whisper_context & wctx) { fprintf(stderr, "%s: loading model\n", __func__); + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + auto & model = wctx.model; auto & vocab = wctx.vocab; @@ -770,9 +773,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); - wctx.probs.reserve (vocab.n_vocab*model.hparams.n_text_ctx); - wctx.probs_id.reserve(n_vocab); + wctx.logits_id.reserve(n_vocab); // TAGS: WHISPER_DECODER_INIT wctx.decoders[0].sequence.tokens.reserve(model.hparams.n_text_ctx); @@ -1178,6 +1180,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } + wctx.t_load_us = ggml_time_us() - t_start_us; + return true; } @@ -1191,9 +1195,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // - mel_offset: offset in the mel spectrogram (i.e. audio offset) // static bool whisper_encode( - whisper_context & wctx, - const int n_threads, - const int mel_offset) { + whisper_context & wctx, + const int mel_offset, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + const auto & model = wctx.model; const auto & mel_inp = wctx.mel; const auto & hparams = model.hparams; @@ -1585,6 +1591,8 @@ static bool whisper_encode( ggml_free(ctx0); + wctx.t_encode_us += ggml_time_us() - t_start_us; + return true; } @@ -1601,10 +1609,12 @@ static bool whisper_encode( static bool whisper_decode( whisper_context & wctx, whisper_decoder & decoder, - const int n_threads, const whisper_token * tokens, const int n_tokens, - const int n_past) { + const int n_past, + const int n_threads) { + const int64_t t_start_us = ggml_time_us(); + const auto & model = wctx.model; const auto & hparams = model.hparams; @@ -1613,7 +1623,6 @@ static bool whisper_decode( WHISPER_ASSERT(!!kv_self.ctx); auto & logits_out = wctx.logits; - auto & probs_out = wctx.probs; const int n_vocab = hparams.n_vocab; @@ -1625,6 +1634,8 @@ static bool whisper_decode( const int N = n_tokens; const int M = wctx.exp_n_audio_ctx > 0 ? 
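// NOTE (editor): M is the number of encoder positions the decoder cross-attends
// to; exp_n_audio_ctx appears to be an experimental override that trades audio
// context for speed.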
wctx.exp_n_audio_ctx : hparams.n_audio_ctx; + //fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx); + struct ggml_init_params params; params.mem_size = wctx.buf_compute.size(); params.mem_buffer = wctx.buf_compute.data(); @@ -1933,25 +1944,18 @@ static bool whisper_decode( struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); - // logits -> probs - cur = ggml_dup(ctx0, logits); - cur = ggml_soft_max(ctx0, cur); // in-place - // run the computation { struct ggml_cgraph gf = {}; gf.n_threads = n_threads; - ggml_build_forward_expand(&gf, cur); + ggml_build_forward_expand(&gf, logits); ggml_graph_compute (ctx0, &gf); } logits_out.resize(N*n_vocab); memcpy(logits_out.data(), ggml_get_data(logits), sizeof(float)*N*n_vocab); - probs_out.resize(N*n_vocab); - memcpy(probs_out.data(), ggml_get_data(cur), sizeof(float)*N*n_vocab); - if (N > 1) { //const float mem_per_token = ggml_used_mem(ctx0)/1024.0/1024.0/N; //printf("%s: used_mem = %f MB / %f per token\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, mem_per_token); @@ -1960,6 +1964,8 @@ static bool whisper_decode( ggml_free(ctx0); + wctx.t_decode_us += ggml_time_us() - t_start_us; + return true; } @@ -2062,16 +2068,18 @@ static void fft(const std::vector & in, std::vector & out) { // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 static bool log_mel_spectrogram( - const float * samples, - const int n_samples, - const int /*sample_rate*/, - const int fft_size, - const int fft_step, - const int n_mel, - const int n_threads, - const whisper_filters & filters, - const bool speed_up, - whisper_mel & mel) { + whisper_context & wctx, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int fft_size, + const int fft_step, + const int n_mel, + const int n_threads, + const whisper_filters & filters, + const bool speed_up, + whisper_mel & mel) { + const int64_t t_start_us = ggml_time_us(); // Hanning window std::vector hann; @@ -2180,6 +2188,8 @@ static bool log_mel_spectrogram( mel.data[i] = (mel.data[i] + 4.0)/4.0; } + wctx.t_mel_us += ggml_time_us() - t_start_us; + return true; } @@ -2324,10 +2334,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { whisper_context * ctx = new whisper_context; - const int64_t t_start_us = ggml_time_us(); - - ctx->t_start_us = t_start_us; - if (!whisper_model_load(loader, *ctx)) { loader->close(loader->context); fprintf(stderr, "%s: failed to load model\n", __func__); @@ -2335,8 +2341,6 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { return nullptr; } - ctx->t_load_us = ggml_time_us() - t_start_us; - loader->close(loader->context); return ctx; @@ -2363,29 +2367,21 @@ void whisper_free(struct whisper_context * ctx) { } int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { + if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, ctx->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } - ctx->t_mel_us = ggml_time_us() - t_start_us; - return 0; } // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 int 
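// NOTE (editor): doubling both the FFT size and the hop length keeps the mel
// frame count consistent with 2x-speed audio, which is what the timestamp
// scaling under params.speed_up in whisper_full assumes.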
whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - if (!log_mel_spectrogram(samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { + if (!log_mel_spectrogram(*ctx, samples, n_samples, WHISPER_SAMPLE_RATE, 2*WHISPER_N_FFT, 2*WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, true, ctx->mel)) { fprintf(stderr, "%s: failed to compute mel spectrogram\n", __func__); return -1; } - ctx->t_mel_us = ggml_time_us() - t_start_us; - return 0; } @@ -2409,31 +2405,23 @@ int whisper_set_mel( } int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { - const int64_t t_start_us = ggml_time_us(); - - if (!whisper_encode(*ctx, n_threads, offset)) { + if (!whisper_encode(*ctx, offset, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return -1; } - ctx->t_encode_us += ggml_time_us() - t_start_us; - return 0; } int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { - const int64_t t_start_us = ggml_time_us(); - // TODO: add selected_decoder_id to context const int selected_decoder_id = 0; - if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], n_threads, tokens, n_tokens, n_past)) { + if (!whisper_decode(*ctx, ctx->decoders[selected_decoder_id], tokens, n_tokens, n_past, n_threads)) { fprintf(stderr, "%s: failed to eval\n", __func__); return 1; } - ctx->t_decode_us += ggml_time_us() - t_start_us; - return 0; } @@ -2517,36 +2505,39 @@ int whisper_lang_auto_detect( return -7; } - auto & probs_id = ctx->probs_id; - probs_id.clear(); + auto & logits_id = ctx->logits_id; + logits_id.clear(); for (const auto & kv : g_lang) { const auto token_lang = whisper_token_lang(ctx, kv.second.first); - probs_id.emplace_back(ctx->probs[token_lang], kv.second.first); + logits_id.emplace_back(ctx->logits[token_lang], kv.second.first); } // sort descending { - using pair_type = std::remove_reference::type::value_type; - std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + using pair_type = std::remove_reference::type::value_type; + std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) { return a.first > b.first; }); } // softmax { - float sum = 0; - for (const auto & kv : probs_id) { - sum += exp(kv.first); + const auto max = logits_id[0].first; + + double sum = 0.0f; + for (auto & kv : logits_id) { + kv.first = exp(kv.first - max); + sum += kv.first; } - for (auto & kv : probs_id) { - kv.first = exp(kv.first) / sum; + for (auto & kv : logits_id) { + kv.first /= sum; } } { - for (const auto & prob : probs_id) { + for (const auto & prob : logits_id) { if (lang_probs) { lang_probs[prob.second] = prob.first; } @@ -2555,7 +2546,7 @@ int whisper_lang_auto_detect( } } - return probs_id[0].second; + return logits_id[0].second; } int whisper_n_len(struct whisper_context * ctx) { @@ -2582,10 +2573,6 @@ float * whisper_get_logits(struct whisper_context * ctx) { return ctx->logits.data(); } -float * whisper_get_probs(struct whisper_context * ctx) { - return ctx->probs.data(); -} - const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { return ctx->vocab.id_to_token.at(token).c_str(); } @@ -2878,13 +2865,11 @@ static void whisper_process_logits( auto & logits = decoder.logits; auto & logprobs = 
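// NOTE (editor): probs/logits/logprobs become per-decoder buffers from this patch
// on (the shared ctx->probs buffer and whisper_get_probs() are removed), so
// concurrent sampling candidates can hold diverging distributions.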
decoder.logprobs; { - probs.resize(n_logits); - memcpy(probs.data(), ctx.probs.data() + (ctx.probs.size() - n_logits), n_logits*sizeof(float)); - logits.resize(n_logits); memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); // will be populated a bit later + probs.resize(n_logits); logprobs.resize(n_logits); } @@ -2904,6 +2889,10 @@ static void whisper_process_logits( // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 logits[vocab.token_not] = -INFINITY; + // suppress sot and solm tokens + logits[vocab.token_sot] = -INFINITY; + logits[vocab.token_solm] = -INFINITY; + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 { @@ -2914,11 +2903,11 @@ static void whisper_process_logits( if (last_was_timestamp) { if (penultimate_was_timestamp) { - for (int i = vocab.token_beg; i < n_logits; ++ i) { + for (int i = vocab.token_beg; i < n_logits; ++i) { logits[i] = -INFINITY; } } else { - for (int i = 0; i < vocab.token_eot; ++ i) { + for (int i = 0; i < vocab.token_eot; ++i) { logits[i] = -INFINITY; } } @@ -2931,7 +2920,7 @@ static void whisper_process_logits( const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; const int tid0 = std::round(params.max_initial_timestamp/precision); - for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++ i) { + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { logits[i] = -INFINITY; } } @@ -2940,12 +2929,19 @@ static void whisper_process_logits( { const float logit_max = *std::max_element(logits.begin(), logits.end()); float logsumexp = 0.0f; - for (int i = 0; i < n_logits; ++ i) { - logsumexp += expf(logits[i] - logit_max); + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } } logsumexp = logf(logsumexp) + logit_max; - for (int i = 0; i < n_logits; ++ i) { - logprobs[i] = logits[i] - logsumexp; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } } } @@ -2957,29 +2953,48 @@ static void whisper_process_logits( { float logsumexp = 0.0f; const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); - for (int i = vocab.token_beg; i < n_logits; ++ i) { - logsumexp += expf(logprobs[i] - logprob_max); + for (int i = vocab.token_beg; i < n_logits; ++i) { + if (logprobs[i] > -INFINITY) { + logsumexp += expf(logprobs[i] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; } - logsumexp = logf(logsumexp) + logprob_max; - timestamp_logprob = logsumexp; } const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + //fprintf(stderr, "timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + if (timestamp_logprob > max_text_token_logprob) { - for (int i = 0; i < vocab.token_beg; ++ i) { - logits[i] = -INFINITY; + for (int i = 0; i < vocab.token_beg; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; } } } } + // compute probs + { + for (int i = 0; i < n_logits; ++i) { + if (logits[i] == -INFINITY) { + probs[i] = 0.0f; + } else { + probs[i] = expf(logprobs[i]); + } + } + } + #if 0 // print first 100 logits - token string : logit for (int 
i = 0; i < 100; i++) { - const auto token = vocab.id_to_token.at(i); - const auto logit = logits[i]; - printf("%s : %f\n", token.c_str(), logit); + const auto token = vocab.id_to_token.at(i); + const auto prob = probs[i]; + const auto logit = logits[i]; + const auto logprob = logprobs[i]; + printf("%s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); } // "And", "and", " And", " and" @@ -2994,26 +3009,13 @@ static void whisper_process_logits( printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); -#endif - - //switch (params.strategy) { - // case WHISPER_SAMPLING_GREEDY: - // { - // // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L249-L274 - // // TODO: implement - // result = (is_initial) ? whisper_sample_timestamp(ctx, true) : whisper_sample_best(ctx); - // } break; - // case WHISPER_SAMPLING_BEAM_SEARCH: - // { - // // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L277C13-L364 - // // TODO: implement - // } break; - //} - - //sum_logprobs += logprobs[result.id]; - //printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1)); - //return result; + printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); +#endif } // select the most probable token @@ -3052,13 +3054,6 @@ static whisper_token_data whisper_sample_best( } for (int i = 0; i < n_logits; ++i) { - // never sample these: - if (i == vocab.token_sot || - i == vocab.token_solm || - i == vocab.token_not) { - continue; - } - if (result.p < probs[i]) { result.id = i; result.p = probs[i]; @@ -3334,32 +3329,42 @@ int whisper_full( // print the prompt //printf("\n\n"); - //for (int i = 0; i < prompt.size(); i++) { + //for (int i = 0; i < (int) prompt.size(); i++) { // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); //} //printf("\n\n"); - if (!whisper_decode(*ctx, ctx->decoders[0], params.n_threads, prompt.data(), prompt.size(), 0)) { + if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } - whisper_process_logits(*ctx, ctx->decoders[0], params); + { + const int64_t t_start_sample_us = ggml_time_us(); - for (int j = 1; j < n_decoders_cur; ++j) { - auto & decoder = ctx->decoders[j]; + whisper_process_logits(*ctx, ctx->decoders[0], params); + + ctx->decoders[0].n_past += prompt.size(); + + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size()); - memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size()); + decoder.n_past += prompt.size(); - decoder.n_past += prompt.size(); + memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); 
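                    // NOTE (editor): these copies clone decoder 0's state: the
                    // self-attention KV cache carries the prompt history, and
                    // probs/logits/logprobs carry the first token's distribution,
                    // so every candidate starts from an identical state.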
+ memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + } - memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); - memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); - memcpy(decoder.logprobs.data(), ctx->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } } for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + const int64_t t_start_sample_us = ggml_time_us(); + for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3407,13 +3412,14 @@ int whisper_full( { const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]"; - printf("%s: %3d, decoder = %d, %10s %6d %6.3f '%s'\n", __func__, i, j, tt.c_str(), token.id, token.pt, ctx->vocab.id_to_token[token.id].c_str()); + printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str()); } // end of segment - if (token.id == whisper_token_eot(ctx) || // end of text token - (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached - (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached ) { if (result_len == 0) { if (seek + seek_delta + 100 >= seek_end) { @@ -3468,6 +3474,8 @@ int whisper_full( } } + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3480,14 +3488,20 @@ int whisper_full( //fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); - if (!whisper_decode(*ctx, decoder, params.n_threads, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past)) { + if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); return -7; } - whisper_process_logits(*ctx, decoder, params); + { + const int64_t t_start_sample_us = ggml_time_us(); + + whisper_process_logits(*ctx, decoder, params); + + ++decoder.n_past; - ++decoder.n_past; + ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + } } } @@ -3502,12 +3516,15 @@ int whisper_full( continue; } - whisper_sequence_score(params, ctx->decoders[j].sequence); + decoder.sequence.tokens.resize(decoder.sequence.result_len); + whisper_sequence_score(params, decoder.sequence); if (best_score < decoder.sequence.score) { best_score = decoder.sequence.score; best_decoder_id = j; } + + fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs); } } @@ -3671,9 +3688,8 @@ int whisper_full_parallel( ctx_p = *ctx; ctx_p.logits.reserve(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); - ctx_p.probs.reserve 
(ctx_p.vocab.n_vocab*ctx_p.model.hparams.n_text_ctx); - ctx_p.probs_id.reserve(ctx_p.vocab.n_vocab); + ctx_p.logits_id.reserve(ctx_p.vocab.n_vocab); if (!kv_cache_reinit(ctx_p.kv_cross)) { fprintf(stderr, "%s: kv_cache_reinit() failed for cross-attention, processor %d\n", __func__, i); diff --git a/whisper.h b/whisper.h index 906c7c0ee68..8c94a7ad15d 100644 --- a/whisper.h +++ b/whisper.h @@ -191,12 +191,6 @@ extern "C" { // Cols: n_vocab WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); - // Token probabilities (i.e. softmax(logits)) obtained from the last call to whisper_decode() - // The probabilities for the last token are stored in the last row - // Rows: n_tokens - // Cols: n_vocab - WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); - // Token Id -> String. Uses the vocabulary in the provided context WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); From 712bc4b960bfacf7a935190c341d14234dedfcfc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 14:11:22 +0200 Subject: [PATCH 12/23] whisper : fix prompt_past update to not include prompt_init --- whisper.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 6a1da95c013..38c994dfbcf 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3180,9 +3180,9 @@ int whisper_full( ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); - ctx->decoders[i].probs.reserve (ctx->vocab.n_vocab); - ctx->decoders[i].logits.reserve (ctx->vocab.n_vocab); - ctx->decoders[i].logprobs.reserve(ctx->vocab.n_vocab); + ctx->decoders[i].probs.resize (ctx->vocab.n_vocab); + ctx->decoders[i].logits.resize (ctx->vocab.n_vocab); + ctx->decoders[i].logprobs.resize(ctx->vocab.n_vocab); } } @@ -3557,7 +3557,7 @@ int whisper_full( // update prompt_past prompt_past.clear(); - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); for (int i = 0; i < result_len; ++i) { prompt_past.push_back(tokens_cur[i].id); From 34c5110f5967d9cdd2d9e0f1fcbf76d80569947a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 15:35:08 +0200 Subject: [PATCH 13/23] whisper : temperature + best_of support --- whisper.cpp | 105 +++++++++++++++++++++++++++++++++------------------- 1 file changed, 67 insertions(+), 38 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 38c994dfbcf..dad11399220 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #define WHISPER_ASSERT(x) \ do { \ @@ -479,6 +480,8 @@ struct whisper_context { std::vector> logits_id; + mutable std::mt19937 rng; // used for sampling at t > 0.0 + // [EXPERIMENTAL] token-level timestamps data int64_t t_beg; int64_t t_last; @@ -1180,6 +1183,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } + wctx.rng = std::mt19937(0); + wctx.t_load_us = ggml_time_us() - t_start_us; return true; @@ -2848,9 +2853,10 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { // - applyies logit filters // - computes logprobs static void whisper_process_logits( - struct whisper_context & ctx, - struct whisper_decoder & decoder, - struct whisper_full_params params) { + const struct whisper_context & ctx, + const struct whisper_full_params params, + struct whisper_decoder & decoder, + float temperature) { const auto & vocab = ctx.vocab; 
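    // NOTE (editor): the steps below mirror OpenAI's logit filters: copy the last
    // token's logits, divide by the temperature when t > 0, force suppressed tokens
    // and rule-violating timestamps to -INFINITY, then compute
    //     logprobs[i] = logits[i] - (max + log(sum_j exp(logits[j] - max)))
    // skipping the -INFINITY entries, and finally probs[i] = expf(logprobs[i]).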
const auto & tokens_cur = decoder.sequence.tokens; @@ -2868,6 +2874,12 @@ static void whisper_process_logits( logits.resize(n_logits); memcpy(logits.data(), ctx.logits.data() + (ctx.logits.size() - n_logits), n_logits*sizeof(float)); + if (temperature > 0.0f) { + for (int i = 0; i < n_logits; i++) { + logits[i] /= temperature; + } + } + // will be populated a bit later probs.resize(n_logits); logprobs.resize(n_logits); @@ -3018,10 +3030,10 @@ static void whisper_process_logits( #endif } -// select the most probable token -static whisper_token_data whisper_sample_best( - whisper_context & ctx, - whisper_decoder & decoder) { +static whisper_token_data whisper_sample_token( + const whisper_context & ctx, + const whisper_decoder & decoder, + bool best) { whisper_token_data result = { 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, }; @@ -3053,12 +3065,20 @@ static whisper_token_data whisper_sample_best( result.ptsum = sum_ts; } - for (int i = 0; i < n_logits; ++i) { - if (result.p < probs[i]) { - result.id = i; - result.p = probs[i]; - result.plog = logprobs[i]; + if (best) { + for (int i = 0; i < n_logits; ++i) { + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } } + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(ctx.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; } return result; @@ -3167,22 +3187,24 @@ int whisper_full( } break; }; - for (int i = 1; i < n_decoders; i++) { - // TAGS: WHISPER_DECODER_INIT - if (ctx->decoders[i].kv_self.ctx == nullptr) { - ctx->decoders[i].kv_self = ctx->decoders[0].kv_self; - if (!kv_cache_reinit(ctx->decoders[i].kv_self)) { - fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, i); + // TAGS: WHISPER_DECODER_INIT + for (int j = 1; j < n_decoders; j++) { + auto & decoder = ctx->decoders[j]; + + if (decoder.kv_self.ctx == nullptr) { + decoder.kv_self = ctx->decoders[0].kv_self; + if (!kv_cache_reinit(decoder.kv_self)) { + fprintf(stderr, "%s: kv_cache_reinit() failed for self-attention, decoder %d\n", __func__, j); return -4; } - fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, i); + fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j); - ctx->decoders[i].sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); + decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity()); - ctx->decoders[i].probs.resize (ctx->vocab.n_vocab); - ctx->decoders[i].logits.resize (ctx->vocab.n_vocab); - ctx->decoders[i].logprobs.resize(ctx->vocab.n_vocab); + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); } } @@ -3268,6 +3290,7 @@ int whisper_full( const float t_cur = temperatures[it]; int n_decoders_cur = 1; + switch (params.strategy) { case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: { @@ -3330,7 +3353,7 @@ int whisper_full( // print the prompt //printf("\n\n"); //for (int i = 0; i < (int) prompt.size(); i++) { - // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token[prompt[i]].c_str()); + // printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); //} //printf("\n\n"); @@ -3342,7 +3365,7 @@ int whisper_full( { const int64_t t_start_sample_us = ggml_time_us(); - whisper_process_logits(*ctx, ctx->decoders[0], params); + whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); 
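                // NOTE (editor): the prompt is evaluated once, by decoder 0, at the
                // current fallback temperature; the state copies below seed the other
                // candidates without re-running the decoder.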
ctx->decoders[0].n_past += prompt.size();
 
@@ -3376,8 +3399,9 @@ int whisper_full(
                     case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
                         {
                             if (t_cur < 1e-6f) {
-                                decoder.sequence.tokens.push_back(whisper_sample_best(*ctx, decoder));
+                                decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
                             } else {
+                                decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
                             }
                         } break;
                     case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
@@ -3402,7 +3426,7 @@ int whisper_full(
                             // do not allow to go back in time
                             if (has_ts && seek_delta > seek_delta_new && result_len < i) {
                                 failed = true; // TODO: maybe this is not a failure ?
-                                break;
+                                continue;
                             }
 
                             seek_delta = seek_delta_new;
@@ -3410,11 +3434,11 @@ int whisper_full(
                             has_ts = true;
                         }
 
-                        {
-                            const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token[token.tid] : "[?]";
-                            printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
-                                    __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token[token.id].c_str());
-                        }
+                        //{
+                        //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
+                        //    printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
+                        //            __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
+                        //}
 
                         // end of segment
                         if (token.id == whisper_token_eot(ctx) ||               // end of text token
@@ -3426,7 +3450,7 @@ int whisper_full(
                                     result_len = i + 1;
                                 } else {
                                     failed = true;
-                                    break;
+                                    continue;
                                 }
                             }
 
@@ -3436,14 +3460,14 @@ int whisper_full(
                             }
 
                             completed = true;
-                            break;
+                            continue;
                         }
 
                         // TESTS: if no tensors are loaded, it means we are running tests
                         if (ctx->model.n_loaded == 0) {
                             seek_delta = 100*WHISPER_CHUNK_SIZE;
                             completed = true;
-                            break;
+                            continue;
                         }
                     }
 
@@ -3451,7 +3475,7 @@ int whisper_full(
                     // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy
                     if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) {
                         failed = true;
-                        break;
+                        continue;
                     }
                 }
 
@@ -3496,7 +3520,7 @@ int whisper_full(
                     {
                         const int64_t t_start_sample_us = ggml_time_us();
 
-                        whisper_process_logits(*ctx, decoder, params);
+                        whisper_process_logits(*ctx, params, decoder, t_cur);
 
                         ++decoder.n_past;
 
@@ -3524,7 +3548,7 @@ int whisper_full(
                         best_decoder_id = j;
                     }
 
-                    fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
+                    //fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
                 }
             }
 
@@ -3541,6 +3565,11 @@ int whisper_full(
             }
 
             if (success) {
+                //fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
+                //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
+                //    fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
+                //}
+
                 break;
             }
 
From c67716faba15ff3ab5f1106281b06b4994c03c03 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Sat, 14 Jan 2023 16:53:19 +0200
Subject: [PATCH 14/23] whisper : support for compression_ratio_threshold

We actually use entropy, but it serves a similar purpose
---
 examples/main/main.cpp |   2 +-
 whisper.cpp            |  73 
+++++++++++++++++++++++++++++++++--------- whisper.h | 2 +- 3 files changed, 60 insertions(+), 17 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 48e02923d01..d149e49f46c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -235,7 +235,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi const char * text = whisper_full_get_token_text(ctx, i, j); const float p = whisper_full_get_token_p (ctx, i, j); - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); } diff --git a/whisper.cpp b/whisper.cpp index dad11399220..45181e31cdf 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -425,6 +425,7 @@ struct whisper_sequence { double sum_logprobs; // the sum of the log probabilities of the tokens double avg_logprobs; // the average log probability of the tokens + double entropy; // the entropy of the tokens double score; // likelihood rank score }; @@ -2700,10 +2701,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_timestamp =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_increment =*/ 0.2f, - /*.compression_ratio_threshold =*/ 2.4f, - /*.logprob_threshold =*/ -1.0f, - /*.no_speech_threshold =*/ 0.6f, + /*.temperature_increment =*/ 0.2f, + /*.entropy_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, /*.greedy =*/ { /*.best_of =*/ 5, @@ -2760,10 +2761,10 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_timestamp =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_increment =*/ 0.2f, - /*.compression_ratio_threshold =*/ 2.4f, - /*.logprob_threshold =*/ -1.0f, - /*.no_speech_threshold =*/ 0.6f, + /*.temperature_increment =*/ 0.2f, + /*.entropy_threshold =*/ 2.4f, + /*.logprob_threshold =*/ -1.0f, + /*.no_speech_threshold =*/ 0.6f, /*.greedy =*/ { /*.best_of =*/ 5, @@ -3081,6 +3082,11 @@ static whisper_token_data whisper_sample_token( result.plog = logprobs[result.id]; } + if (result.id >= vocab.token_beg) { + result.tid = result.id; + result.pt = result.p; + } + return result; } @@ -3106,6 +3112,28 @@ static void whisper_sequence_score( } sequence.score = result/penalty; + + // compute the entropy of the sequence of the last 32 tokens + { + const int n = 32; + + int cnt = 0; + double entropy = 0.0f; + + std::map token_counts; + for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { + token_counts[sequence.tokens[i].id]++; + cnt++; + } + + for (const auto & kv : token_counts) { + const auto p = kv.second/(double)cnt; + entropy -= p*log(p); + //printf("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + } + + sequence.entropy = entropy; + } } int whisper_full( @@ -3322,9 +3350,10 @@ int whisper_full( decoder.sequence.tokens.clear(); decoder.sequence.result_len = 0; - decoder.sequence.sum_logprobs = 0.0; - decoder.sequence.avg_logprobs = 0.0; - decoder.sequence.score = 0.0; + decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0f; + decoder.sequence.score = -INFINITY; decoder.n_past = 0; decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; @@ -3543,12 +3572,22 @@ int whisper_full( 
decoder.sequence.tokens.resize(decoder.sequence.result_len);
                     whisper_sequence_score(params, decoder.sequence);
 
+                    fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n",
+                            __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy);
+
+                    if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) {
+                        fprintf(stderr, "%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n",
+                                __func__, j, decoder.sequence.entropy, params.entropy_threshold);
+
+                        decoder.failed = true;
+
+                        continue;
+                    }
+
                     if (best_score < decoder.sequence.score) {
                         best_score = decoder.sequence.score;
                         best_decoder_id = j;
                     }
-
-                    //fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs);
                 }
             }
 
@@ -3559,13 +3598,13 @@ int whisper_full(
             {
                 auto & decoder = ctx->decoders[best_decoder_id];
 
-                if (decoder.sequence.avg_logprobs < params.logprob_threshold) {
+                if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_threshold) {
                     success = false;
                 }
             }
 
             if (success) {
-                //fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
+                fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id);
                 //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) {
                 //    fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str());
                 //}
 
                 break;
             }
 
@@ -3623,6 +3662,8 @@ int whisper_full(
                         }
                     }
 
+                    //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid);
+
                     result_all.push_back({ tt0, tt1, text, {} });
                     for (int j = i0; j <= i; j++) {
                         result_all.back().tokens.push_back(tokens_cur[j]);
@@ -3690,6 +3731,8 @@ int whisper_full(
 
             // update audio window
             seek += seek_delta;
+
+            fprintf(stderr, "seek = %d, seek_delta = %d\n", seek, seek_delta);
         }
     }
 
diff --git a/whisper.h b/whisper.h
index 8c94a7ad15d..69a2246c0e1 100644
--- a/whisper.h
+++ b/whisper.h
@@ -279,7 +279,7 @@ extern "C" {
 
         // fallback parameters
         float temperature_increment;
-        float compression_ratio_threshold;
+        float entropy_threshold;  // analogous to OpenAI's compression_ratio_threshold
         float logprob_threshold;
         float no_speech_threshold;
 
From 7ea1b736ec19301e2557449153a57352edf98018 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Sat, 14 Jan 2023 17:01:24 +0200
Subject: [PATCH 15/23] command : fix example to use logits instead of obsolete
 probs
---
 examples/command/command.cpp | 107 +++++++++++++++++++++++------------
 1 file changed, 66 insertions(+), 41 deletions(-)

diff --git a/examples/command/command.cpp b/examples/command/command.cpp
index 3dae3a5e31c..2bdaf87c45c 100644
--- a/examples/command/command.cpp
+++ b/examples/command/command.cpp
@@ -671,56 +671,81 @@ int process_command_list(struct whisper_context * ctx, audio_async &audio, const
             break;
         }
 
-        const auto * probs = whisper_get_probs(ctx);
-        std::vector<std::pair<float, int>> probs_id;
-
-        double psum = 0.0;
-        for (int i = 0; i < (int) allowed_commands.size(); ++i) {
-            probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
-            for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
-                probs_id.back().first += probs[allowed_tokens[i][j]];
-            }
-            probs_id.back().first /= 
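// NOTE (editor): this removed block and its indented replacement further below score
// a command the same way: average the softmaxed probabilities of the command's tokens
// from a single decode step, hence the "NOTE: not optimal" remark in the new code.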
allowed_tokens[i].size(); - psum += probs_id.back().first; - } + // estimate command probability + // NOTE: not optimal + { + const auto * logits = whisper_get_logits(ctx); - // normalize - for (auto & p : probs_id) { - p.first /= psum; - } + std::vector probs(whisper_n_vocab(ctx), 0.0f); - // sort descending - { - using pair_type = decltype(probs_id)::value_type; - std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { - return a.first > b.first; - }); - } + // compute probs from logits via softmax + { + float max = -1e9; + for (int i = 0; i < (int) probs.size(); ++i) { + max = std::max(max, logits[i]); + } - // print the commands and the respective probabilities - { - fprintf(stdout, "\n"); - for (const auto & cmd : probs_id) { - fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); - for (int token : allowed_tokens[cmd.second]) { - fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + float sum = 0.0f; + for (int i = 0; i < (int) probs.size(); ++i) { + probs[i] = expf(logits[i] - max); + sum += probs[i]; + } + + for (int i = 0; i < (int) probs.size(); ++i) { + probs[i] /= sum; } + } + + std::vector> probs_id; + + double psum = 0.0; + for (int i = 0; i < (int) allowed_commands.size(); ++i) { + probs_id.emplace_back(probs[allowed_tokens[i][0]], i); + for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) { + probs_id.back().first += probs[allowed_tokens[i][j]]; + } + probs_id.back().first /= allowed_tokens[i].size(); + psum += probs_id.back().first; + } + + // normalize + for (auto & p : probs_id) { + p.first /= psum; + } + + // sort descending + { + using pair_type = decltype(probs_id)::value_type; + std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // print the commands and the respective probabilities + { fprintf(stdout, "\n"); + for (const auto & cmd : probs_id) { + fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first); + for (int token : allowed_tokens[cmd.second]) { + fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]); + } + fprintf(stdout, "\n"); + } } - } - // best command - { - const auto t_end = std::chrono::high_resolution_clock::now(); + // best command + { + const auto t_end = std::chrono::high_resolution_clock::now(); - const float prob = probs_id[0].first; - const int index = probs_id[0].second; + const float prob = probs_id[0].first; + const int index = probs_id[0].second; - fprintf(stdout, "\n"); - fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, - "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, - (int) std::chrono::duration_cast(t_end - t_start).count()); - fprintf(stdout, "\n"); + fprintf(stdout, "\n"); + fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__, + "\033[1m", allowed_commands[index].c_str(), "\033[0m", prob, + (int) std::chrono::duration_cast(t_end - t_start).count()); + fprintf(stdout, "\n"); + } } audio.clear(); From c6a8a4703984e58dba2d45f538d25f83337ec8cf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 17:48:02 +0200 Subject: [PATCH 16/23] whisper : handle empty sequence ranking --- whisper.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/whisper.cpp b/whisper.cpp index 45181e31cdf..b964aa1f7eb 100644 --- a/whisper.cpp 
+++ b/whisper.cpp @@ -3094,7 +3094,9 @@ static whisper_token_data whisper_sample_token( static void whisper_sequence_score( const struct whisper_full_params & params, whisper_sequence & sequence) { - WHISPER_ASSERT(sequence.result_len > 0); + if (sequence.result_len == 0) { + return; + } double result = 0.0f; From c301a7942b615f7fb102b5ebbcc9e0638d4f5b00 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 18:25:16 +0200 Subject: [PATCH 17/23] whisper : add WHISPER_DEBUG + diagnostic prints + new main args --- README.md | 12 +---- examples/main/main.cpp | 119 +++++++++++++++++++++++------------------ whisper.cpp | 48 ++++++++++------- whisper.h | 2 +- 4 files changed, 99 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index f22724a5054..448e7588059 100644 --- a/README.md +++ b/README.md @@ -212,17 +212,7 @@ make large ## Limitations - Inference only -- No GPU support -- Very basic greedy sampling scheme - always pick up the token with highest probability. - This should be similar to the [GreedyDecoder](https://github.com/openai/whisper/blob/main/whisper/decoding.py#L249-L274) - from the original python implementation, so in order to make a fair comparison between the 2 implementations, make sure - to run the python code with the following parameters: - - ``` - whisper --best_of None --beam_size None ... - ``` - - In the future, `whisper.cpp` will support more sampling strategies. +- No GPU support (yet) ## Another example diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d149e49f46c..18c434e6a88 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -59,8 +59,12 @@ struct whisper_params { int32_t duration_ms = 0; int32_t max_context = -1; int32_t max_len = 0; + int32_t best_of = 5; + int32_t beam_size = -1; - float word_thold = 0.01f; + float word_thold = 0.01f; + float entropy_thold = 2.4f; + float logprob_thold = -1.0f; bool speed_up = false; bool translate = false; @@ -104,7 +108,11 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } @@ -136,31 +144,35 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during 
computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); - fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); - fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); - fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); - fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? 
"false" : "true"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? 
"false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); fprintf(stderr, "\n"); } @@ -331,20 +343,19 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); - if (text[0] == ' ') - text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. + if (text[0] == ' ') { + text = text + sizeof(char); //whisper_full_get_segment_text() returns a string with leading space, point to the next character. + } const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. - fout << 10 * t0 << ", " - << 10 * t1 << ", \"" - << text << "\"\n"; + + //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. + fout << 10 * t0 << ", " << 10 * t1 << ", \"" << text << "\"\n"; } return true; } - // karaoke video generation // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments @@ -620,25 +631,29 @@ int main(int argc, char ** argv) { { whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - wparams.print_realtime = false; - wparams.print_progress = params.print_progress; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special = params.print_special; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; - wparams.offset_ms = params.offset_t_ms; - wparams.duration_ms = params.duration_ms; - - wparams.token_timestamps = params.output_wts || params.max_len > 0; - wparams.thold_pt = params.word_thold; - wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; - - wparams.speed_up = params.speed_up; - - wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.entropy_threshold = params.entropy_thold; + wparams.logprob_threshold = params.logprob_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 
60 : params.max_len;
+
+            wparams.speed_up         = params.speed_up;
+
+            wparams.prompt_tokens    = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
+            wparams.prompt_n_tokens  = prompt_tokens.empty() ? 0 : prompt_tokens.size();
 
             whisper_print_user_data user_data = { &params, &pcmf32s };
 
diff --git a/whisper.cpp b/whisper.cpp
index b964aa1f7eb..385d183fa43 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -25,6 +25,16 @@
     } \
 } while (0)
 
+#define WHISPER_DEBUG
+#if defined(WHISPER_DEBUG)
+#define WHISPER_PRINT_DEBUG(...) \
+    do { \
+        fprintf(stderr, __VA_ARGS__); \
+    } while (0)
+#else
+#define WHISPER_PRINT_DEBUG(...)
+#endif
+
 #define WHISPER_USE_FLASH_ATTN
 //#define WHISPER_USE_FLASH_FF
 #define WHISPER_MAX_DECODERS 16
@@ -1640,7 +1650,7 @@ static bool whisper_decode(
     const int N = n_tokens;
     const int M = wctx.exp_n_audio_ctx > 0 ? wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    //fprintf(stderr, "n_past = %d, N = %d, M = %d, n_ctx = %d\n", n_past, N, M, n_ctx);
+    WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params;
     params.mem_size   = wctx.buf_compute.size();
@@ -3177,8 +3187,8 @@ int whisper_full(
     }
 
     if (params.token_timestamps) {
-        ctx->t_beg = 0;
-        ctx->t_last = 0;
+        ctx->t_beg    = 0;
+        ctx->t_last   = 0;
         ctx->tid_last = 0;
         ctx->energy = get_signal_energy(samples, n_samples, 32);
     }
@@ -3228,7 +3238,7 @@ int whisper_full(
             return -4;
         }
 
-        fprintf(stderr, "%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
+        WHISPER_PRINT_DEBUG("%s: initialized self-attention kv cache, decoder %d\n", __func__, j);
 
         decoder.sequence.tokens.reserve(ctx->decoders[0].sequence.tokens.capacity());
 
@@ -3338,12 +3348,12 @@ int whisper_full(
                 } break;
         };
 
-        fprintf(stderr, "\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
+        WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
 
         if (t_cur > 0.5) {
             prompt_past.clear();
 
-            fprintf(stderr, "%s: clearing prompt_past\n", __func__);
+            WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__);
         }
 
         // TAGS: WHISPER_DECODER_INIT
@@ -3465,11 +3475,13 @@ int whisper_full(
                     has_ts = true;
                 }
 
-                //{
-                //    const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]";
-                //    printf("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n",
-                //            __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str());
-                //}
+#ifdef WHISPER_DEBUG
+                {
+                    const auto tt = token.pt > 0.10 ?
ctx->vocab.id_to_token.at(token.tid) : "[?]"; + WHISPER_PRINT_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + } +#endif // end of segment if (token.id == whisper_token_eot(ctx) || // end of text token @@ -3541,7 +3553,7 @@ int whisper_full( decoder.tokens_tmp.resize(1); decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; - //fprintf(stderr, "%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); @@ -3574,11 +3586,11 @@ int whisper_full( decoder.sequence.tokens.resize(decoder.sequence.result_len); whisper_sequence_score(params, decoder.sequence); - fprintf(stderr, "%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) { - fprintf(stderr, "%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", __func__, j, decoder.sequence.entropy, params.entropy_threshold); decoder.failed = true; @@ -3606,15 +3618,15 @@ int whisper_full( } if (success) { - fprintf(stderr, "%s: best decoder = %d\n", __func__, best_decoder_id); + WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { - // fprintf(stderr, "%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); //} break; } - fprintf(stderr, "\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); + WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } { @@ -3734,7 +3746,7 @@ int whisper_full( // update audio window seek += seek_delta; - fprintf(stderr, "seek = %d, seek_delta = %d\n", seek, seek_delta); + WHISPER_PRINT_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); } } diff --git a/whisper.h b/whisper.h index 69a2246c0e1..4bcb0e6ca07 100644 --- a/whisper.h +++ b/whisper.h @@ -281,7 +281,7 @@ extern "C" { float temperature_increment; float entropy_threshold; // analog tho OpenAI's compression_ratio_threshold float logprob_threshold; - float no_speech_threshold; + float no_speech_threshold; // TODO: not implemented struct { int best_of; From 5e97f80fc5ac60c4247583de7399975a32124213 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 19:30:53 +0200 Subject: [PATCH 18/23] whisper : minor fixes --- examples/stream/stream.cpp | 6 
+++++-
 whisper.cpp                | 24 +++++++++++-------------
 2 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp
index c7aa87178f9..c01a8df125a 100644
--- a/examples/stream/stream.cpp
+++ b/examples/stream/stream.cpp
@@ -459,7 +459,7 @@ int main(int argc, char ** argv) {
     struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
 
     std::vector<float> pcmf32    (n_samples_30s, 0.0f);
-    std::vector<float> pcmf32_old(n_samples_30s, 0.0f);
+    std::vector<float> pcmf32_old;
     std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
 
     std::vector<whisper_token> prompt_tokens;
@@ -615,6 +615,10 @@ int main(int argc, char ** argv) {
             wparams.audio_ctx     = params.audio_ctx;
             wparams.speed_up      = params.speed_up;
 
+            // disable best_of fallback
+            wparams.temperature_increment = -1.0f;
+            wparams.greedy.best_of        = -1;
+
             wparams.prompt_tokens   = params.no_context ? nullptr : prompt_tokens.data();
             wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();
 
diff --git a/whisper.cpp b/whisper.cpp
index 385d183fa43..73afa7c5ab9 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -25,7 +25,7 @@
     } \
 } while (0)
 
-#define WHISPER_DEBUG
+//#define WHISPER_DEBUG
 #if defined(WHISPER_DEBUG)
 #define WHISPER_PRINT_DEBUG(...) \
     do { \
@@ -3216,6 +3216,7 @@ int whisper_full(
     // initialize the decoders
     int n_decoders = 1;
 
+
     switch (params.strategy) {
         case WHISPER_SAMPLING_GREEDY:
             {
@@ -3227,6 +3228,8 @@
             } break;
     };
 
+    n_decoders = std::max(1, n_decoders);
+
     // TAGS: WHISPER_DECODER_INIT
     for (int j = 1; j < n_decoders; j++) {
         auto & decoder = ctx->decoders[j];
@@ -3348,13 +3351,9 @@ int whisper_full(
                 } break;
         };
 
-        WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
-
-        if (t_cur > 0.5) {
-            prompt_past.clear();
+        n_decoders_cur = std::max(1, n_decoders_cur);
 
-            WHISPER_PRINT_DEBUG("%s: clearing prompt_past\n", __func__);
-        }
+        WHISPER_PRINT_DEBUG("\n%s: decoding with %d decoders, temperature = %.2f\n", __func__, n_decoders_cur, t_cur);
 
         // TAGS: WHISPER_DECODER_INIT
@@ -3381,7 +3380,7 @@ int whisper_full(
             prompt.clear();
 
             // if we have already generated some text, use it as a prompt to condition the next generation
-            if (!prompt_past.empty()) {
+            if (!prompt_past.empty() && t_cur > 0.5f) {
                 int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size()));
 
                 prompt = { whisper_token_prev(ctx) };
@@ -3392,11 +3391,11 @@
             prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
 
             // print the prompt
-            //printf("\n\n");
+            //WHISPER_PRINT_DEBUG("\n\n");
             //for (int i = 0; i < (int) prompt.size(); i++) {
-            //    printf("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
+            //    WHISPER_PRINT_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str());
             //}
-            //printf("\n\n");
+            //WHISPER_PRINT_DEBUG("\n\n");
 
             if (!whisper_decode(*ctx, ctx->decoders[0], prompt.data(), prompt.size(), 0, params.n_threads)) {
                 fprintf(stderr, "%s: failed to decode\n", __func__);
@@ -3608,7 +3607,6 @@ int whisper_full(
         bool success = true;
 
         // implement logprob threshold
-        // TODO: implement compression threshold
        {
            auto & decoder = ctx->decoders[best_decoder_id];
 
@@ -3646,7 +3644,7 @@ int whisper_full(
        }
 
        // store the text from this iteration
-       if (!tokens_cur.empty()) {
+       if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
            int  i0 = 0;
            auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
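At this point in the series it is worth stepping back to see how the temperature-fallback machinery fits together: `whisper_full()` builds a schedule of temperatures, decodes the current window at the lowest one, and retries at the next temperature whenever the best sequence fails the entropy or average-logprob checks. The sketch below is a minimal standalone illustration of that loop, not code from the patches; `Sequence` and `decode_at_temperature` are hypothetical stand-ins for the decoder state and a full decoding pass.

```
#include <cstdio>
#include <vector>

// simplified stand-in for whisper_sequence
struct Sequence {
    double avg_logprobs;
    double entropy;
    int    result_len;
};

// hypothetical decode pass; dummy values arranged so low temperatures fail
static Sequence decode_at_temperature(float t) {
    return { -1.5 + t, 3.0, 32 };
}

int main() {
    const float temperature     = 0.0f;
    const float temperature_inc = 0.2f;  // <= 0 disables the fallback (as the stream example does)
    const float entropy_thold   = 2.4f;
    const float logprob_thold   = -1.0f;

    // [ t0, t0 + inc, t0 + 2*inc, ..., < 1.0f + 1e-6f ] - the schedule built in whisper_full()
    std::vector<float> temperatures;
    if (temperature_inc > 0.0f) {
        for (float t = temperature; t < 1.0f + 1e-6f; t += temperature_inc) {
            temperatures.push_back(t);
        }
    } else {
        temperatures.push_back(temperature);
    }

    for (const float t : temperatures) {
        const Sequence seq = decode_at_temperature(t);

        // same spirit as the checks in whisper_full(): long results must have
        // high enough entropy, and the average logprob must clear logprob_thold
        const bool failed =
            (seq.result_len > 8 && seq.entropy < entropy_thold) ||
            (seq.avg_logprobs < logprob_thold);

        if (!failed) {
            printf("accepted decode at temperature %.2f\n", t);
            return 0;
        }

        printf("decode failed at temperature %.2f, falling back\n", t);
    }

    return 1; // all temperatures failed
}
```

With the dummy values above, the loop rejects t = 0.0 through 0.4 on the average-logprob check and accepts at t = 0.6, mirroring how the full implementation escalates the temperature only when the best decoder's sequence is rejected.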
From 5548a1986fe5b4885f666e5c99d292898f7f3cae Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 22:49:06 +0200 Subject: [PATCH 19/23] whisper : add beam-search support --- examples/main/main.cpp | 3 + whisper.cpp | 200 +++++++++++++++++++++++++++++++++++++---- whisper.h | 2 +- 3 files changed, 187 insertions(+), 18 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 18c434e6a88..d52e1d73e1e 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -652,6 +652,9 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; + wparams.greedy.best_of = params.best_of; + wparams.beam_search.beam_size = params.beam_size; + wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); diff --git a/whisper.cpp b/whisper.cpp index 73afa7c5ab9..0fb0f426867 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -433,10 +433,11 @@ struct whisper_sequence { // the accumulated transcription in the current interation (used to truncate the tokens array) int result_len; - double sum_logprobs; // the sum of the log probabilities of the tokens - double avg_logprobs; // the average log probability of the tokens - double entropy; // the entropy of the tokens - double score; // likelihood rank score + double sum_logprobs_all; // the sum of the log probabilities of the tokens + double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens) + double avg_logprobs; // the average log probability of the tokens + double entropy; // the entropy of the tokens + double score; // likelihood rank score }; // TAGS: WHISPER_DECODER_INIT @@ -1650,7 +1651,7 @@ static bool whisper_decode( const int N = n_tokens; const int M = wctx.exp_n_audio_ctx > 0 ? 
wctx.exp_n_audio_ctx : hparams.n_audio_ctx;
 
-    WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
+    //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params;
     params.mem_size   = wctx.buf_compute.size();
@@ -3100,6 +3101,74 @@ static whisper_token_data whisper_sample_token(
     return result;
 }
 
+static std::vector<whisper_token_data> whisper_sample_token_topk(
+        whisper_context & ctx,
+        const whisper_decoder & decoder,
+        int k) {
+    const auto & vocab = ctx.vocab;
+
+    const auto & probs    = decoder.probs;
+    const auto & logits   = decoder.logits;
+    const auto & logprobs = decoder.logprobs;
+
+    const int n_logits = vocab.n_vocab;
+
+    auto & logits_id = ctx.logits_id;
+
+    logits_id.clear();
+    for (int i = 0; i < n_logits; ++i) {
+        logits_id.push_back({ logits[i], i });
+    }
+
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + k, logits_id.end(),
+            [](const std::pair<double, whisper_token> & a, const std::pair<double, whisper_token> & b) {
+        return a.first > b.first;
+    });
+
+    std::vector<whisper_token_data> result(k);
+
+    {
+        double sum_ts = 0.0;
+        double max_ts = 0.0;
+
+        for (int i = vocab.token_beg; i < n_logits; i++) {
+            if (probs[i] == -INFINITY) {
+                continue;
+            }
+
+            sum_ts += probs[i];
+            if (max_ts < probs[i]) {
+                max_ts = probs[i];
+                result[0].tid = i;
+            }
+        }
+
+        result[0].pt    = max_ts/(sum_ts + 1e-10);
+        result[0].ptsum = sum_ts;
+    }
+
+    for (int i = 0; i < k; ++i) {
+        result[i].id    = logits_id[i].second;
+        result[i].p     = probs[result[i].id];
+        result[i].plog  = logprobs[result[i].id];
+        result[i].tid   = result[0].tid;
+        result[i].pt    = result[0].pt;
+        result[i].ptsum = result[0].ptsum;
+        result[i].t0    = -1;
+        result[i].t1    = -1;
+        result[i].vlen  = 0.0f;
+
+        if (result[i].id >= vocab.token_beg) {
+            result[i].tid = result[i].id;
+            result[i].pt  = result[i].p;
+        }
+    }
+
+    return result;
+}
+
 // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192
 static void whisper_sequence_score(
         const struct whisper_full_params & params,
@@ -3120,7 +3189,7 @@ static void whisper_sequence_score(
     double penalty = sequence.result_len;
 
     if (params.length_penalty > 0.0f) {
-        penalty = pow((5.0 + penalty) / 6.0, params.length_penalty);
+        penalty = pow((5.0 + penalty)/6.0, params.length_penalty);
     }
 
     sequence.score = result/penalty;
@@ -3141,7 +3210,8 @@ static void whisper_sequence_score(
         for (const auto & kv : token_counts) {
             const auto p = kv.second/(double)cnt;
             entropy -= p*log(p);
-            //printf("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
+
+            //WHISPER_PRINT_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second);
         }
 
         sequence.entropy = entropy;
@@ -3293,6 +3363,25 @@ int whisper_full(
     std::vector<whisper_token> prompt;
     prompt.reserve(whisper_n_text_ctx(ctx));
 
+    // beam-search helpers
+    struct kv_buf {
+        std::vector<uint8_t> k;
+        std::vector<uint8_t> v;
+    };
+
+    std::vector<kv_buf> kv_bufs;
+
+    struct beam_candidate {
+        int decoder_idx;
+        int seek_delta;
+
+        bool has_ts;
+
+        whisper_sequence sequence;
+    };
+
+    std::vector<beam_candidate> beam_candidates;
+
     // main loop
     while (true) {
         const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start);
@@ -3360,11 +3449,12 @@
             auto & decoder = ctx->decoders[j];
 
             decoder.sequence.tokens.clear();
-            decoder.sequence.result_len = 0;
-            decoder.sequence.sum_logprobs = -INFINITY;
-            decoder.sequence.avg_logprobs = -INFINITY;
-            decoder.sequence.entropy = 0.0f;
-            decoder.sequence.score = -INFINITY;
+            decoder.sequence.result_len       = 0;
+            decoder.sequence.sum_logprobs_all = 0.0;
+ decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0; + decoder.sequence.score = -INFINITY; decoder.n_past = 0; decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; @@ -3412,7 +3502,8 @@ int whisper_full( for (int j = 1; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; - memcpy(decoder.kv_self.buf.data(), ctx->decoders[0].kv_self.buf.data(), decoder.kv_self.buf.size()); + memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); + memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); decoder.n_past += prompt.size(); @@ -3428,6 +3519,25 @@ int whisper_full( for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { const int64_t t_start_sample_us = ggml_time_us(); + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + kv_bufs.resize(n_decoders_cur); + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + kv_bufs[j].k.resize(ggml_nbytes(decoder.kv_self.k)); + kv_bufs[j].v.resize(ggml_nbytes(decoder.kv_self.v)); + + memcpy(kv_bufs[j].k.data(), decoder.kv_self.k->data, kv_bufs[j].k.size()); + memcpy(kv_bufs[j].v.data(), decoder.kv_self.v->data, kv_bufs[j].v.size()); + } + + beam_candidates.clear(); + } + for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3443,12 +3553,65 @@ int whisper_full( } else { decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; } break; case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: { - // TODO: .. 
+ const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence }); + beam_candidates.back().sequence.tokens.push_back(token); + beam_candidates.back().sequence.sum_logprobs_all += token.plog; + + //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all); + } } break; }; + } + + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + std::sort( + beam_candidates.begin(), + beam_candidates.end(), + [](const beam_candidate & a, const beam_candidate & b) { + return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; + }); + + int cur_c = 0; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & cur = beam_candidates[cur_c++]; + + while (beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) { + ++cur_c; + } + + decoder.sequence = cur.sequence; + decoder.seek_delta = cur.seek_delta; + decoder.has_ts = cur.has_ts; + + memcpy(decoder.kv_self.k->data, kv_bufs[cur.decoder_idx].k.data(), kv_bufs[cur.decoder_idx].k.size()); + memcpy(decoder.kv_self.v->data, kv_bufs[cur.decoder_idx].v.data(), kv_bufs[cur.decoder_idx].v.size()); + + WHISPER_PRINT_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", + __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); + } + } + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = ctx->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } auto & has_ts = decoder.has_ts; auto & failed = decoder.failed; @@ -3659,6 +3822,7 @@ int whisper_full( } else { text += whisper_token_to_str(ctx, tokens_cur[i].id); } + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); if (!text.empty()) { @@ -4077,6 +4241,8 @@ static void whisper_exp_compute_token_level_timestamps( p1--; } + //printf("p0=%d p1=%d t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); + if (p1 > p0) { double psum = 0.0; for (int j = p0; j <= p1; j++) { @@ -4205,11 +4371,11 @@ static void whisper_exp_compute_token_level_timestamps( // debug info //for (int j = 0; j < n; ++j) { // const auto & token = tokens[j]; - // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? 
whisper_token_to_str(&ctx, token.tid) : "[?]"; // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, - // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id)); - // if (tokens[j].id >= whisper_token_eot(ctx)) { + // if (tokens[j].id >= whisper_token_eot(&ctx)) { // continue; // } //} diff --git a/whisper.h b/whisper.h index 4bcb0e6ca07..1266e0f974a 100644 --- a/whisper.h +++ b/whisper.h @@ -290,7 +290,7 @@ extern "C" { struct { int beam_size; - float patience; + float patience; // TODO: not implemented } beam_search; whisper_new_segment_callback new_segment_callback; From 6700cd57f705fecce8a2f80a7b17f8b944573b18 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Jan 2023 23:04:16 +0200 Subject: [PATCH 20/23] whisper : bug fix when there no previous context --- whisper.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/whisper.cpp b/whisper.cpp index 0fb0f426867..aad77b8e175 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3798,9 +3798,13 @@ int whisper_full( const auto & tokens_cur = best_decoder.sequence.tokens; + //WHISPER_PRINT_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + // update prompt_past prompt_past.clear(); - prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + if (prompt.front() == whisper_token_prev(ctx)) { + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + } for (int i = 0; i < result_len; ++i) { prompt_past.push_back(tokens_cur[i].id); From d83e47573b85732bba732840916f8a20a04aae9b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Jan 2023 09:04:40 +0200 Subject: [PATCH 21/23] whisper : add comments --- examples/main/main.cpp | 36 ++--- examples/stream/stream.cpp | 4 +- whisper.cpp | 264 +++++++++++++++++-------------------- whisper.h | 45 ++++--- 4 files changed, 163 insertions(+), 186 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index d52e1d73e1e..7dd9800f091 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -633,24 +633,24 @@ int main(int argc, char ** argv) { wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; - wparams.print_realtime = false; - wparams.print_progress = params.print_progress; - wparams.print_timestamps = !params.no_timestamps; - wparams.print_special = params.print_special; - wparams.translate = params.translate; - wparams.language = params.language.c_str(); - wparams.n_threads = params.n_threads; - wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; - wparams.offset_ms = params.offset_t_ms; - wparams.duration_ms = params.duration_ms; - - wparams.token_timestamps = params.output_wts || params.max_len > 0; - wparams.thold_pt = params.word_thold; - wparams.entropy_threshold = params.entropy_thold; - wparams.logprob_threshold = params.logprob_thold; - wparams.max_len = params.output_wts && params.max_len == 0 ? 
60 : params.max_len; - - wparams.speed_up = params.speed_up; + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; + wparams.offset_ms = params.offset_t_ms; + wparams.duration_ms = params.duration_ms; + + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.entropy_thold = params.entropy_thold; + wparams.logprob_thold = params.logprob_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + + wparams.speed_up = params.speed_up; wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index c01a8df125a..3432cb5331a 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -616,8 +616,8 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; // disable best_of fallback - wparams.temperature_increment = -1.0f; - wparams.greedy.best_of = -1; + wparams.temperature_inc = -1.0f; + wparams.greedy.best_of = -1; wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); diff --git a/whisper.cpp b/whisper.cpp index aad77b8e175..c40085675ba 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -25,7 +25,9 @@ } \ } while (0) +// define this to enable verbose trace logging - useful for debugging purposes //#define WHISPER_DEBUG + #if defined(WHISPER_DEBUG) #define WHISPER_PRINT_DEBUG(...) \ do { \ @@ -380,6 +382,8 @@ struct whisper_kv_cache { struct ggml_context * ctx; std::vector buf; + + int n; // number of tokens currently in the cache }; struct whisper_model { @@ -442,12 +446,13 @@ struct whisper_sequence { // TAGS: WHISPER_DECODER_INIT struct whisper_decoder { + // each decoders keeps its own KV-cache whisper_kv_cache kv_self; + // the currently generated sequence of tokens whisper_sequence sequence; - int n_past; - int seek_delta; + int seek_delta; // the window shift found so far based on the decoded timestamp tokens bool failed; // has the current segment failed to decode? bool completed; // has the decoder completed the current segment? 
@@ -476,6 +481,8 @@ struct whisper_context { whisper_model model; whisper_vocab vocab; + // cross-attention KV cache for the decoders + // shared between all decoders whisper_kv_cache kv_cross; whisper_decoder decoders[WHISPER_MAX_DECODERS] = {}; @@ -490,6 +497,7 @@ struct whisper_context { std::vector result_all; std::vector prompt_past; + // work container used to avoid memory allocations std::vector> logits_id; mutable std::mt19937 rng; // used for sampling at t > 0.0 @@ -680,6 +688,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con mem_required / 1024.0 / 1024.0, mem_required_decoder / 1024.0 / 1024.0); } + // initialize all memory buffers + // always have at least one decoder + wctx.model.buf = new std::vector(); wctx.model.buf->resize(scale*MEM_REQ_MODEL.at(model.type)); @@ -2671,127 +2682,77 @@ const char * whisper_print_system_info(void) { //////////////////////////////////////////////////////////////////////////// struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { - struct whisper_full_params result; + struct whisper_full_params result = { + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, - switch (strategy) { - case WHISPER_SAMPLING_GREEDY: - { - result = { - /*.strategy =*/ WHISPER_SAMPLING_GREEDY, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, + /*.translate =*/ false, + /*.no_context =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.max_tokens =*/ 0, - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, + /*.speed_up =*/ false, + /*.audio_ctx =*/ 0, - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, + /*.language =*/ "en", - /*.language =*/ "en", + /*.suppress_blank =*/ true, - /*.suppress_blank =*/ true, + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, - /*.temperature =*/ 0.0f, - /*.max_initial_timestamp =*/ 1.0f, - /*.length_penalty =*/ -1.0f, + /*.temperature_inc =*/ 0.2f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, - /*.temperature_increment =*/ 0.2f, - /*.entropy_threshold =*/ 2.4f, - /*.logprob_threshold =*/ -1.0f, - /*.no_speech_threshold =*/ 0.6f, + /*.greedy =*/ { + /*.best_of =*/ -1, + }, - /*.greedy =*/ { - /*.best_of =*/ 5, - }, + /*.beam_search =*/ { + /*.beam_size =*/ -1, - /*.beam_search =*/ { - /*.beam_size =*/ -1, + /*.patience =*/ -1.0f, + }, - /*.patience =*/ -1.0f, - }, + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ 
nullptr, + }; - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, + switch (strategy) { + case WHISPER_SAMPLING_GREEDY: + { + result.greedy = { + /*.best_of =*/ 1, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: { - result = { - /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, - - /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), - /*.n_max_text_ctx =*/ 16384, - /*.offset_ms =*/ 0, - /*.duration_ms =*/ 0, - - /*.translate =*/ false, - /*.no_context =*/ false, - /*.single_segment =*/ false, - /*.print_special =*/ false, - /*.print_progress =*/ true, - /*.print_realtime =*/ false, - /*.print_timestamps =*/ true, - - /*.token_timestamps =*/ false, - /*.thold_pt =*/ 0.01f, - /*.thold_ptsum =*/ 0.01f, - /*.max_len =*/ 0, - /*.max_tokens =*/ 0, - - /*.speed_up =*/ false, - /*.audio_ctx =*/ 0, - - /*.prompt_tokens =*/ nullptr, - /*.prompt_n_tokens =*/ 0, - - /*.language =*/ "en", - - /*.suppress_blank =*/ true, - - /*.temperature =*/ 0.0f, - /*.max_initial_timestamp =*/ 1.0f, - /*.length_penalty =*/ -1.0f, - - /*.temperature_increment =*/ 0.2f, - /*.entropy_threshold =*/ 2.4f, - /*.logprob_threshold =*/ -1.0f, - /*.no_speech_threshold =*/ 0.6f, + result.beam_search = { + /*.beam_size =*/ 5, - /*.greedy =*/ { - /*.best_of =*/ 5, - }, - - /*.beam_search =*/ { - /*.beam_size =*/ 5, - - /*.patience =*/ -1.0f, - }, - - /*.new_segment_callback =*/ nullptr, - /*.new_segment_callback_user_data =*/ nullptr, - - /*.encoder_begin_callback =*/ nullptr, - /*.encoder_begin_callback_user_data =*/ nullptr, + /*.patience =*/ -1.0f, }; } break; } @@ -2862,8 +2823,8 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) { } // process the logits for the selected decoder -// - applyies logit filters -// - computes logprobs +// - applies logit filters +// - computes logprobs and probs static void whisper_process_logits( const struct whisper_context & ctx, const struct whisper_full_params params, @@ -2938,11 +2899,11 @@ static void whisper_process_logits( } } - // the initial timestamp cannot be larger than max_initial_timestamp + // the initial timestamp cannot be larger than max_initial_ts // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 - if (is_initial && params.max_initial_timestamp > 0.0f) { + if (is_initial && params.max_initial_ts > 0.0f) { const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; - const int tid0 = std::round(params.max_initial_timestamp/precision); + const int tid0 = std::round(params.max_initial_ts/precision); for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { logits[i] = -INFINITY; @@ -3073,7 +3034,7 @@ static whisper_token_data whisper_sample_token( } } - result.pt = max_ts/(sum_ts + 1e-10); + result.pt = max_ts/(sum_ts + 1e-10); result.ptsum = sum_ts; } @@ -3127,7 +3088,13 @@ static std::vector whisper_sample_token_topk( return a.first > b.first; }); - std::vector result(k); + std::vector result; + result.reserve(k); + + whisper_token tid; + + float pt; + float ptsum; { double sum_ts = 0.0; @@ -3141,24 +3108,18 @@ static std::vector whisper_sample_token_topk( sum_ts += probs[i]; if (max_ts < probs[i]) { max_ts = probs[i]; - result[0].tid = i; + tid = i; } } - result[0].pt = max_ts/(sum_ts + 1e-10); - result[0].ptsum = sum_ts; + pt = max_ts/(sum_ts + 1e-10); + ptsum = sum_ts; } for (int i = 0; i < k; ++i) { - result[i].id = logits_id[i].second; - result[i].p = probs[result[i].id]; - result[i].plog = 
logprobs[result[i].id]; - result[i].tid = result[0].tid; - result[i].pt = result[0].pt; - result[i].ptsum = result[0].ptsum; - result[i].t0 = -1; - result[i].t1 = -1; - result[i].vlen = 0.0f; + const auto id = logits_id[i].second; + + result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); if (result[i].id >= vocab.token_beg) { result[i].tid = result[i].id; @@ -3276,8 +3237,8 @@ int whisper_full( // a set of temperatures to use // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] std::vector temperatures; - if (params.temperature_increment > 0.0f) { - for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_increment) { + if (params.temperature_inc > 0.0f) { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) { temperatures.push_back(t); } } else { @@ -3448,6 +3409,8 @@ int whisper_full( for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; + decoder.kv_self.n = 0; + decoder.sequence.tokens.clear(); decoder.sequence.result_len = 0; decoder.sequence.sum_logprobs_all = 0.0; @@ -3456,7 +3419,6 @@ int whisper_full( decoder.sequence.entropy = 0.0; decoder.sequence.score = -INFINITY; - decoder.n_past = 0; decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; decoder.failed = false; @@ -3497,7 +3459,7 @@ int whisper_full( whisper_process_logits(*ctx, params, ctx->decoders[0], t_cur); - ctx->decoders[0].n_past += prompt.size(); + ctx->decoders[0].kv_self.n += prompt.size(); for (int j = 1; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3505,7 +3467,7 @@ int whisper_full( memcpy(decoder.kv_self.k->data, ctx->decoders[0].kv_self.k->data, ggml_nbytes(decoder.kv_self.k)); memcpy(decoder.kv_self.v->data, ctx->decoders[0].kv_self.v->data, ggml_nbytes(decoder.kv_self.v)); - decoder.n_past += prompt.size(); + decoder.kv_self.n += prompt.size(); memcpy(decoder.probs.data(), ctx->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); memcpy(decoder.logits.data(), ctx->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); @@ -3519,6 +3481,7 @@ int whisper_full( for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { const int64_t t_start_sample_us = ggml_time_us(); + // store the KV caches of all decoders when doing beam-search if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { kv_bufs.resize(n_decoders_cur); for (int j = 0; j < n_decoders_cur; ++j) { @@ -3538,6 +3501,7 @@ int whisper_full( beam_candidates.clear(); } + // generate new sequence candidates for each decoder for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3571,6 +3535,7 @@ int whisper_full( }; } + // for beam-search, choose the top candidates and update the KV caches if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { std::sort( beam_candidates.begin(), @@ -3606,6 +3571,10 @@ int whisper_full( } } + // update the decoder state + // - check if the sequence is completed + // - check if the sequence is failed + // - update sliding window based on timestamp tokens for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3705,6 +3674,7 @@ int whisper_full( ctx->t_sample_us += ggml_time_us() - t_start_sample_us; + // obtain logits for the next token for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3715,11 +3685,11 @@ int whisper_full( decoder.tokens_tmp.resize(1); decoder.tokens_tmp[0] = decoder.sequence.tokens.back().id; - 
//WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, n_past %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.n_past, decoder.seek_delta); + //WHISPER_PRINT_DEBUG("%s: decoder %d: token %d, kv_self.n %d, seek_delta %d\n", __func__, j, decoder.tokens_tmp[0], decoder.kv_self.n, decoder.seek_delta); - if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.n_past, params.n_threads)) { + if (!whisper_decode(*ctx, decoder, decoder.tokens_tmp.data(), decoder.tokens_tmp.size(), decoder.kv_self.n, params.n_threads)) { fprintf(stderr, "%s: failed to decode\n", __func__); - return -7; + return -8; } { @@ -3727,7 +3697,7 @@ int whisper_full( whisper_process_logits(*ctx, params, decoder, t_cur); - ++decoder.n_past; + ++decoder.kv_self.n; ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } @@ -3736,7 +3706,7 @@ int whisper_full( // rank the resulting sequences and select the best one { - double best_score = -1e9; + double best_score = -INFINITY; for (int j = 0; j < n_decoders_cur; ++j) { auto & decoder = ctx->decoders[j]; @@ -3751,9 +3721,9 @@ int whisper_full( WHISPER_PRINT_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); - if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_threshold) { + if (decoder.sequence.result_len > 8 && decoder.sequence.entropy < params.entropy_thold) { WHISPER_PRINT_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", - __func__, j, decoder.sequence.entropy, params.entropy_threshold); + __func__, j, decoder.sequence.entropy, params.entropy_thold); decoder.failed = true; @@ -3765,31 +3735,33 @@ int whisper_full( best_decoder_id = j; } } - } - bool success = true; + WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); + } - // implement logprob threshold + // was the decoding successful for the current temperature? { - auto & decoder = ctx->decoders[best_decoder_id]; + bool success = true; + + const auto & decoder = ctx->decoders[best_decoder_id]; - if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_threshold) { + if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { success = false; } - } - if (success) { - WHISPER_PRINT_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); - //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { - // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); - //} + if (success) { + //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // WHISPER_PRINT_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + //} - break; + break; + } } WHISPER_PRINT_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); } + // output results through a user-provided callback { const auto & best_decoder = ctx->decoders[best_decoder_id]; diff --git a/whisper.h b/whisper.h index 1266e0f974a..84504b7b23f 100644 --- a/whisper.h +++ b/whisper.h @@ -137,6 +137,7 @@ extern "C" { // tokens + n_tokens is the provided context for the decoder. 
// n_past is the number of tokens to use from previous decoder calls.
     // Returns 0 on success
+    // TODO: add support for multiple decoders
     WHISPER_API int whisper_decode(
             struct whisper_context * ctx,
             const whisper_token * tokens,
@@ -218,8 +219,8 @@ extern "C" {
 
     // Available sampling strategies
     enum whisper_sampling_strategy {
-        WHISPER_SAMPLING_GREEDY,      // Always select the most probable token
-        WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
+        WHISPER_SAMPLING_GREEDY,      // similar to OpenAI's GreefyDecoder
+        WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
     };
 
     // Text segment callback
@@ -239,17 +240,17 @@ extern "C" {
         enum whisper_sampling_strategy strategy;
 
         int n_threads;
-        int n_max_text_ctx;
+        int n_max_text_ctx;     // max tokens to use from past text as prompt for the decoder
         int offset_ms;          // start offset in ms
         int duration_ms;        // audio duration to process in ms
 
         bool translate;
-        bool no_context;
+        bool no_context;        // do not use initial prompt for the decoder (if any)
         bool single_segment;    // force single segment output (useful for streaming)
-        bool print_special;
-        bool print_progress;
-        bool print_realtime;
-        bool print_timestamps;
+        bool print_special;     // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;    // print progress information
+        bool print_realtime;    // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps;  // print timestamps for each text segment when printing realtime
 
         // [EXPERIMENTAL] token-level timestamps
         bool  token_timestamps; // enable token-level timestamps
@@ -259,10 +260,11 @@ extern "C" {
         int   max_tokens;       // max tokens per segment (0 = no limit)
 
         // [EXPERIMENTAL] speed-up techniques
+        // note: these can significantly reduce the quality of the output
         bool speed_up;          // speed-up the audio by 2x using Phase Vocoder
         int  audio_ctx;         // overwrite the audio context size (0 = use default)
 
-        // tokens to provide the whisper model as initial prompt
+        // tokens to provide to the whisper decoder as initial prompt
         // these are prepended to any existing text context from a previous call
         const whisper_token * prompt_tokens;
         int prompt_n_tokens;
 
@@ -271,31 +273,34 @@ extern "C" {
         const char * language;
 
         // common decoding parameters:
-        bool suppress_blank;
+        bool suppress_blank;  // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
 
-        float temperature;
-        float max_initial_timestamp;
-        float length_penalty;
+        float temperature;    // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
+        float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
+        float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
 
         // fallback parameters
-        float temperature_increment;
-        float entropy_threshold; // analog tho OpenAI's compression_ratio_threshold
-        float logprob_threshold;
-        float no_speech_threshold; // TODO: not implemented
+        // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
+        float temperature_inc;
+        float entropy_thold;  // similar to OpenAI's "compression_ratio_threshold"
+        float logprob_thold;
+        float no_speech_thold; // TODO: not implemented
 
         struct {
-            int best_of;
+            int best_of;    // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
         } greedy;
 
         struct {
-            int
beam_size; + int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265 - float patience; // TODO: not implemented + float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf } beam_search; + // called for every newly generated text segment whisper_new_segment_callback new_segment_callback; void * new_segment_callback_user_data; + // called each time before the encoder starts whisper_encoder_begin_callback encoder_begin_callback; void * encoder_begin_callback_user_data; }; From 3fe33d61a20ec89a50976df30431ad49c716b966 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Jan 2023 11:07:52 +0200 Subject: [PATCH 22/23] stream : disable temperature fallback For real-time processing, we always want a single decoder running at T=0 --- examples/main/main.cpp | 1 + examples/stream.wasm/emscripten.cpp | 3 +++ examples/stream/stream.cpp | 3 +-- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 7dd9800f091..65b06ca516a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -654,6 +654,7 @@ int main(int argc, char ** argv) { wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; + wparams.temperature_inc = -1; wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); diff --git a/examples/stream.wasm/emscripten.cpp b/examples/stream.wasm/emscripten.cpp index e4cdf639a40..144a14d268f 100644 --- a/examples/stream.wasm/emscripten.cpp +++ b/examples/stream.wasm/emscripten.cpp @@ -49,6 +49,9 @@ void stream_main(size_t index) { wparams.max_tokens = 32; wparams.audio_ctx = 768; // partial encoder context for better performance + // disable temperature fallback + wparams.temperature_inc = -1.0f; + wparams.language = "en"; printf("stream: using %d threads\n", wparams.n_threads); diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 3432cb5331a..e1251704f5d 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -615,9 +615,8 @@ int main(int argc, char ** argv) { wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; - // disable best_of fallback + // disable temperature fallback wparams.temperature_inc = -1.0f; - wparams.greedy.best_of = -1; wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_n_tokens = params.no_context ? 
0 : prompt_tokens.size(); From 6a2f4dbcb39e5e2f2b5c5977afb2768a1eff4bc2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Jan 2023 11:25:00 +0200 Subject: [PATCH 23/23] whisper.swiftui : update example - fix paths + add empty folders --- .../whisper.swiftui.demo/Resources/models/.gitignore | 0 .../whisper.swiftui.demo/Resources/samples/.gitignore | 0 .../whisper.swiftui.xcodeproj/project.pbxproj | 11 ++++++----- 3 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore create mode 100644 examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore diff --git a/examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore b/examples/whisper.swiftui/whisper.swiftui.demo/Resources/models/.gitignore new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore b/examples/whisper.swiftui/whisper.swiftui.demo/Resources/samples/.gitignore new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj index 9cc09c09b52..cc0afbcae4f 100644 --- a/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj +++ b/examples/whisper.swiftui/whisper.swiftui.xcodeproj/project.pbxproj @@ -35,10 +35,10 @@ 0AAC5DA029539CD0003032C3 /* WhisperCppDemo.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = WhisperCppDemo.entitlements; sourceTree = ""; }; 0AAC5DA229539CD0003032C3 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; 0AAC5DC629539EAF003032C3 /* WhisperCppDemo-Bridging-Header.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "WhisperCppDemo-Bridging-Header.h"; sourceTree = ""; }; - 0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = whisper.cpp; path = ../../../whisper.cpp; sourceTree = ""; }; - 0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = whisper.h; path = ../../../whisper.h; sourceTree = ""; }; - 0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; name = ggml.c; path = ../../../ggml.c; sourceTree = ""; }; - 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; name = ggml.h; path = ../../../ggml.h; sourceTree = ""; }; + 0AAC5DC729539EB0003032C3 /* whisper.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = whisper.cpp; sourceTree = ""; }; + 0AAC5DC829539EB0003032C3 /* whisper.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = whisper.h; sourceTree = ""; }; + 0AAC5DC929539EB0003032C3 /* ggml.c */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.c; path = ggml.c; sourceTree = ""; }; + 0AAC5DCA29539EB0003032C3 /* ggml.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ggml.h; sourceTree = ""; }; 0AAC5DCD2953A05C003032C3 /* WhisperState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = WhisperState.swift; sourceTree = ""; 
};
 		0AAC5DD02953A394003032C3 /* LibWhisper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibWhisper.swift; sourceTree = "<group>"; };
 /* End PBXFileReference section */
 
@@ -129,7 +129,8 @@
 				0AAC5DC729539EB0003032C3 /* whisper.cpp */,
 				0AAC5DC829539EB0003032C3 /* whisper.h */,
 			);
-			path = whisper.cpp;
+			name = whisper.cpp;
+			path = ../..;
 			sourceTree = "<group>";
 		};
 		0AAC5DCF2953A36C003032C3 /* whisper.cpp.swift */ = {
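For reference, the heart of the beam-search decoding added in PATCH 19 is the per-token candidate ranking: every live decoder proposes its top-`beam_size` continuations, all candidates are pooled and sorted by cumulative log probability, and the best distinct candidates are handed back to the decoders (restoring the saved KV cache of whichever decoder each candidate originated from). Below is a minimal standalone sketch of just that ranking step; `Candidate` is a simplified stand-in for `beam_candidate`, and the scores are made up for illustration.

```
#include <algorithm>
#include <cstdio>
#include <vector>

// simplified stand-in for beam_candidate: one proposed continuation
struct Candidate {
    int    decoder_idx;      // decoder this continuation came from
    int    token_id;         // proposed next token
    double sum_logprobs_all; // cumulative log probability of the sequence
};

int main() {
    // top-2 continuations gathered from two live decoders (made-up scores)
    std::vector<Candidate> candidates = {
        { 0, 50364, -1.10 }, { 0,  400, -2.30 },
        { 1, 50364, -1.10 }, { 1, 1770, -0.90 },
    };

    // rank the pooled candidates, best (highest logprob) first
    std::sort(candidates.begin(), candidates.end(),
            [](const Candidate & a, const Candidate & b) {
        return a.sum_logprobs_all > b.sum_logprobs_all;
    });

    // hand the best distinct candidates back to the decoders; the real code
    // also copies the KV cache snapshot saved for each candidate's decoder
    const int n_decoders = 2;

    size_t cur_c = 0;
    for (int j = 0; j < n_decoders && cur_c < candidates.size(); ++j) {
        const Candidate & cur = candidates[cur_c++];

        // skip score ties so two decoders do not continue identical sequences,
        // mirroring the tie-skipping loop in whisper_full()
        while (cur_c < candidates.size() &&
               candidates[cur_c].sum_logprobs_all == cur.sum_logprobs_all) {
            ++cur_c;
        }

        printf("decoder %d <- token %d from decoder %d (sum_logprobs_all = %.2f)\n",
                j, cur.token_id, cur.decoder_idx, cur.sum_logprobs_all);
    }

    return 0;
}
```

Here decoder 0 adopts the -0.90 candidate from decoder 1, decoder 1 adopts one of the two tied -1.10 candidates, and the remaining tied candidate is skipped rather than assigned to a second decoder.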