Add heuristic mode
bobqianic authored Feb 9, 2024
1 parent c0277e3 commit b6d89b0
Showing 2 changed files with 46 additions and 20 deletions.
whisper.cpp: 64 changes (45 additions, 19 deletions)
@@ -2845,7 +2845,6 @@ static bool log_mel_spectrogram(
const int n_mel,
const int n_threads,
const whisper_filters & filters,
- const bool debug,
whisper_mel & mel) {
const int64_t t_start_us = ggml_time_us();

@@ -2917,17 +2916,6 @@ static bool log_mel_spectrogram(

wstate.t_mel_us += ggml_time_us() - t_start_us;

- // Dump log_mel_spectrogram
- if (debug) {
- std::ofstream outFile("log_mel_spectrogram.json");
- outFile << "[";
- for (uint64_t i = 0; i < mel.data.size() - 1; i++) {
- outFile << mel.data[i] << ", ";
- }
- outFile << mel.data[mel.data.size() - 1] << "]";
- outFile.close();
- }

return true;
}

@@ -3598,7 +3586,7 @@ void whisper_free_params(struct whisper_full_params * params) {
}

int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, state->mel)) {
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
@@ -3612,7 +3600,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int

// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good)
int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) {
- if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) {
+ if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, state->mel)) {
WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__);
return -1;
}
@@ -4530,7 +4518,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_tokens =*/ 0,

/*.speed_up =*/ false,
- /*.debug_mode =*/ false,
/*.audio_ctx =*/ 0,

/*.tdrz_enable =*/ false,
@@ -4544,6 +4531,7 @@

/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ true,
+ /*.heuristic =*/ true,

/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
@@ -4785,7 +4773,7 @@ static void whisper_no_speech_probs(
}

static const std::vector<std::string> non_speech_tokens = {
- "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
+ "\"", "#", "*", "+", "/", ":", ";", "<", "=", ">", "@", "\\", "^",
"_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--",
"---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
"♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
@@ -4967,14 +4955,13 @@ static void whisper_process_logits(
float timestamp_logprob = -INFINITY;
{
float logsumexp = 0.0f;
- const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end());
for (int i = vocab.token_beg; i < n_logits; ++i) {
if (logprobs[i] > -INFINITY) {
- logsumexp += expf(logprobs[i] - logprob_max);
+ logsumexp += expf(logprobs[i]);
}
}
if (logsumexp > 0.0f) {
- timestamp_logprob = logf(logsumexp) + logprob_max;
+ timestamp_logprob = logf(logsumexp);
}
}

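For context, the loop above folds the probability of every timestamp token into a single log-probability via log-sum-exp; whisper_process_logits then compares that value against the most likely text token and, when the timestamp mass wins, suppresses the text tokens for the current step. A minimal standalone sketch of that comparison, with toy values rather than real model output:

```cpp
// Standalone illustration of the timestamp rule in whisper_process_logits
// (toy values; token_beg stands for vocab.token_beg).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // toy log-probabilities: ids [0, token_beg) are text tokens, [token_beg, n) are timestamps
    std::vector<float> logprobs = { -1.2f, -0.8f, -3.5f, -2.0f, -2.3f, -2.1f };
    const int token_beg = 3;

    // log of the summed probability of all timestamp tokens (log-sum-exp)
    float logsumexp = 0.0f;
    for (size_t i = token_beg; i < logprobs.size(); ++i) {
        if (logprobs[i] > -INFINITY) {
            logsumexp += expf(logprobs[i]);
        }
    }
    const float timestamp_logprob = logsumexp > 0.0f ? logf(logsumexp) : -INFINITY;

    // highest log-probability among the ordinary text tokens
    const float max_text_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + token_beg);

    // if the timestamp mass outweighs every single text token, the decoder samples a timestamp here
    if (timestamp_logprob > max_text_logprob) {
        printf("force timestamp (%.3f > %.3f)\n", timestamp_logprob, max_text_logprob);
    } else {
        printf("keep text tokens (%.3f <= %.3f)\n", timestamp_logprob, max_text_logprob);
    }
    return 0;
}
```
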
@@ -5435,6 +5422,7 @@ int whisper_full_with_state(
}

int seek = seek_start;
+ bool fast_forward = false;

std::vector<whisper_token> prompt;
prompt.reserve(whisper_n_text_ctx(ctx));
@@ -5991,6 +5979,42 @@ int whisper_full_with_state(
if (best_decoder.sequence.no_speech_probs > params.no_speech_thold) {
if (best_decoder.sequence.avg_logprobs < params.logprob_thold) {
// fast-forward to the next segment boundary
+ prompt_past.clear();
+ fast_forward = true;
+ seek += std::min(3000, state->mel.n_len_org - seek);
+ continue;
+ }
+ }
+ }
+
+ // repetition check
+ {
+ if (!params.no_timestamps) {
+ const auto & best_decoder = state->decoders[best_decoder_id];
+ const auto & tokens_cur = best_decoder.sequence.tokens;
+
+ std::set<std::string> table;
+ std::string text;
+ int timestamp_token_counter = 0;
+ int max_length = 0;
+
+ for (auto & token : tokens_cur) {
+ if (token.id < whisper_token_beg(ctx)) {
+ text += ctx->vocab.id_to_token[token.id];
+ } else {
+ timestamp_token_counter++;
+ }
+ if (timestamp_token_counter % 2 == 0) {
+ if (text.length() > max_length) {max_length = text.length();}
+ table.insert(text);
+ text.clear();
+ }
+ }
+
+ if ((static_cast<float>(table.size()) / static_cast<float>(timestamp_token_counter)) < 0.25 || max_length <= 4) {
+ // fast-forward to the next segment boundary
+ prompt_past.clear();
+ fast_forward = true;
seek += std::min(3000, state->mel.n_len_org - seek);
continue;
}
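
Read on its own, the repetition check added above splits the decoded tokens into per-segment texts at timestamp-token pairs and fast-forwards the audio window when too few of those texts are unique, or when the longest one is only a few characters. The sketch below restates the idea with the token bookkeeping stripped away; note that the commit itself divides by the timestamp-token count rather than by the number of segments:

```cpp
// Standalone sketch of the repetition heuristic: fast-forward when the decoded
// window is mostly the same short text repeated. "segments" stands for the texts
// between consecutive timestamp-token pairs.
#include <algorithm>
#include <cstdio>
#include <set>
#include <string>
#include <vector>

static bool looks_degenerate(const std::vector<std::string> & segments) {
    if (segments.empty()) {
        return false;
    }

    std::set<std::string> unique;
    size_t max_length = 0;
    for (const auto & text : segments) {
        unique.insert(text);
        max_length = std::max(max_length, text.size());
    }

    const float unique_ratio = float(unique.size()) / float(segments.size());

    // same thresholds as the commit: mostly duplicate segments, or only very short ones
    return unique_ratio < 0.25f || max_length <= 4;
}

int main() {
    // a typical hallucination loop: the same phrase decoded over and over
    const std::vector<std::string> looped(12, " thank you.");
    // a healthy window: distinct segments of reasonable length
    const std::vector<std::string> normal = {
        " hello everyone,", " welcome back,", " today we look at mel spectrograms.",
    };

    printf("looped: %s\n", looks_degenerate(looped) ? "fast-forward" : "keep");
    printf("normal: %s\n", looks_degenerate(normal) ? "fast-forward" : "keep");
    return 0;
}
```
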
@@ -6113,6 +6137,8 @@
}
}

+ fast_forward = false;

// update audio window
// https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/transcribe.py#L353-L361
{
whisper.h: 2 changes (1 addition, 1 deletion)
@@ -457,7 +457,6 @@ extern "C" {
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
- bool debug_mode; // enable debug_mode provides extra info (eg. Dump log_mel)
int audio_ctx; // overwrite the audio context size (0 = use default)

// [EXPERIMENTAL] [TDRZ] tinydiarize
@@ -476,6 +475,7 @@ extern "C" {
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+ bool heuristic;

float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
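
On the API side, the new heuristic field is a plain boolean on whisper_full_params and defaults to true (see the defaults hunk above), so a caller opts out the same way it toggles the other decoding flags. A usage sketch along those lines; the model path and audio buffer are placeholders, not part of this commit:

```cpp
// Hedged usage sketch: the model path and the silent PCM buffer are placeholders,
// only the heuristic flag itself comes from this commit.
#include "whisper.h"

#include <vector>

int main() {
    struct whisper_context_params cparams = whisper_context_default_params();
    struct whisper_context * ctx = whisper_init_from_file_with_params("ggml-base.en.bin", cparams);
    if (ctx == nullptr) {
        return 1;
    }

    // 5 seconds of 16 kHz mono PCM; a real caller would load audio here
    std::vector<float> pcm(16000 * 5, 0.0f);

    struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.heuristic = false; // opt out of the fast-forward heuristics (default is true)

    const int ret = whisper_full(ctx, wparams, pcm.data(), (int) pcm.size());

    whisper_free(ctx);
    return ret == 0 ? 0 : 1;
}
```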
