Skip to content

Commit

Permalink
whisper : add whisper_get_logits()
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Jan 8, 2023
1 parent a1c8a65 commit e3c6416
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
25 changes: 20 additions & 5 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);

wctx.work_logits.reserve(vocab.n_vocab);
wctx.work_logprobs.reserve(vocab.n_vocab);

vocab.probs_id.reserve(n_vocab);
}

Expand Down Expand Up @@ -1004,11 +1007,11 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
model.memory_cross_v = ggml_new_tensor_1d(ctx, wtype, n_elements);
}

const size_t memory_size =
ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v) +
ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
const size_t memory_cross_size = ggml_nbytes(model.memory_cross_k) + ggml_nbytes(model.memory_cross_v);

fprintf(stderr, "%s: memory size = %7.2f MB\n", __func__, memory_size/1024.0/1024.0);
fprintf(stderr, "%s: memory x size = %7.2f MB\n", __func__, memory_cross_size/1024.0/1024.0);
}

// load weights
Expand Down Expand Up @@ -2579,6 +2582,10 @@ int whisper_is_multilingual(struct whisper_context * ctx) {
return ctx->vocab.is_multilingual() ? 1 : 0;
}

// Expose the raw logits buffer produced by the most recent whisper_decode() call.
// Layout: one row per decoded token, n_vocab columns; the logits for the last
// token are stored in the last row.
float * whisper_get_logits(struct whisper_context * ctx) {
    auto & logits = ctx->logits;
    return logits.data();
}

// Expose the token probabilities (softmax of the logits) from the most recent
// whisper_decode() call. Layout: one row per decoded token, n_vocab columns;
// the probabilities for the last token are stored in the last row.
float * whisper_get_probs(struct whisper_context * ctx) {
    auto & probs = ctx->probs;
    return probs.data();
}
Expand Down Expand Up @@ -2841,14 +2848,15 @@ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
static struct whisper_token_data whisper_sample_next_token(
struct whisper_context * ctx,
struct whisper_full_params params,
double & sum_logprobs,
const std::vector<whisper_token> & prompt,
const std::vector<whisper_token_data> & tokens_cur) {
struct whisper_token_data result = {};

const auto & vocab = ctx->vocab;

const bool is_initial = tokens_cur.size() == 0;
const int n_logits = vocab.id_to_token.size();
const int n_logits = vocab.id_to_token.size();

WHISPER_ASSERT(n_logits == ctx->vocab.n_vocab);

Expand Down Expand Up @@ -2948,6 +2956,7 @@ static struct whisper_token_data whisper_sample_next_token(
}
}

#if 0
// print first 100 logits - token string : logit
for (int i = 0; i < 100; i++) {
const auto token = vocab.id_to_token.at(i);
Expand All @@ -2967,6 +2976,7 @@ static struct whisper_token_data whisper_sample_next_token(
printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]);
printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]);
printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" so")]);
#endif

switch (params.strategy) {
case WHISPER_SAMPLING_GREEDY:
Expand All @@ -2982,6 +2992,9 @@ static struct whisper_token_data whisper_sample_next_token(
} break;
}

sum_logprobs += logprobs[result.id];
printf("token: %16s, logprob: %8.3f, sum_logprobs: %8.3f, avg_logprob: %8.3f\n", vocab.id_to_token.at(result.id).c_str(), logprobs[result.id], sum_logprobs, sum_logprobs/(tokens_cur.size() + 1));

return result;
}

Expand Down Expand Up @@ -3150,6 +3163,8 @@ int whisper_full(
bool failed = false; // has the current segment failed to decode?
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?

double sum_logprobs = 0.0;

for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) {
if (whisper_decode(ctx, prompt.data(), prompt.size(), n_past, params.n_threads) != 0) {
fprintf(stderr, "%s: failed to decode\n", __func__);
Expand All @@ -3162,7 +3177,7 @@ int whisper_full(
// sample the next token based on the selected decoding strategy + parameters
// also, update the sliding window position based on the sampled timestamp tokens
{
const auto token = whisper_sample_next_token(ctx, params, prompt, tokens_cur);
const auto token = whisper_sample_next_token(ctx, params, sum_logprobs, prompt, tokens_cur);

// timestamp token - update sliding window
if (token.id > whisper_token_beg(ctx)) {
Expand Down
13 changes: 11 additions & 2 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ extern "C" {

// Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode().
// You can also implement your own sampling method using the whisper_get_probs() function.
// You can also implement your own sampling method using the whisper_get_logits() or whisper_get_probs() functions.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
Expand Down Expand Up @@ -192,7 +192,16 @@ extern "C" {
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);

// The probabilities for the next token
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);

// Token probabilities (i.e. softmax(logits)) obtained from the last call to whisper_decode()
// The probabilities for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);

// Token Id -> String. Uses the vocabulary in the provided context
Expand Down

0 comments on commit e3c6416

Please sign in to comment.