diff --git a/whisper.cpp b/whisper.cpp index ffff47ed543..b0c710de2b8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -620,6 +620,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx); wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx); + wctx.work_logits.reserve(vocab.n_vocab); + wctx.work_logprobs.reserve(vocab.n_vocab); + vocab.probs_id.reserve(n_vocab); } @@ -1004,11 +1007,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements); } - const size_t memory_size = - ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) + - ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v); fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0); + fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0); } // load weights @@ -2579,6 +2582,10 @@ int whisper_is_multilingual(struct whisper_context * ctx) { return ctx->vocab.is_multilingual() ? 
1 : 0; } +float * whisper_get_logits(struct whisper_context * ctx) { + return ctx->logits.data(); +} + float * whisper_get_probs(struct whisper_context * ctx) { return ctx->probs.data(); } @@ -2841,6 +2848,7 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { static struct whisper_token_data whisper_sample_next_token( struct whisper_context * ctx, struct whisper_full_params params, + double & sum_logprobs, const std::vector & prompt, const std::vector & tokens_cur) { struct whisper_token_data result = {}; @@ -2848,7 +2856,7 @@ static struct whisper_token_data whisper_sample_next_token( const auto & vocab = ctx->vocab; const bool is_initial = tokens_cur.size() == 0; - const int n_logits = vocab.id_to_token.size(); + const int n_logits = vocab.id_to_token.size(); WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab); @@ -2948,6 +2956,7 @@ static struct whisper_token_data whisper_sample_next_token( } } +#if 0 // print first 100 logits - token string : logit for (int i = 0; i < 100; i++) { const auto token = vocab.id_to_token.at(i); @@ -2967,6 +2976,7 @@ static struct whisper_token_data whisper_sample_next_token( printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]); +#endif switch (params.strategy) { case WHISPER_SAMPLING_GREEDY: @@ -2982,6 +2992,9 @@ static struct whisper_token_data whisper_sample_next_token( } break; } + sum_logprobs += logprobs[result.id]; + printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1)); + return result; } @@ -3150,6 +3163,8 @@ int whisper_full( bool failed = false; // has the current segment failed to decode? bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? 
+ double sum_logprobs = 0.0; + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) { fprintf(stderr, "%s: failed to decode\n", __func__); @@ -3162,7 +3177,7 @@ int whisper_full( // sample the next token based on the selected decoding strategy + parameters // also, update the sliding window position based on the sampled timestamp tokens { - const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur); + const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur); // timestamp token - update sliding window if (token.id > whisper_token_beg(ctx)) { diff --git a/whisper.h b/whisper.h index 1f7c82ec740..cdb1fb6cb14 100644 --- a/whisper.h +++ b/whisper.h @@ -145,7 +145,7 @@ extern "C" { // Token sampling methods. // These are provided for convenience and can be used after each call to whisper_decode(). - // You can also implement your own sampling method using the whisper_get_probs() function. + // You can also implement your own sampling method using the whisper_get_logits() or whisper_get_probs() functions. // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); @@ -192,7 +192,16 @@ extern "C" { WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx); - // The probabilities for the next token + // Token logits obtained from the last call to whisper_decode() + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + WHISPER_API float * whisper_get_logits(struct whisper_context * ctx); + + // Token probabilities (i.e. 
softmax(logits)) obtained from the last call to whisper_decode() + // The probabilities for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab WHISPER_API float * whisper_get_probs(struct whisper_context * ctx); // Token Id -> String. Uses the vocabulary in the provided context