Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
bobqianic authored Feb 5, 2024
1 parent e2e5177 commit 8a46034
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 59 deletions.
204 changes: 147 additions & 57 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ struct whisper_segment {

std::vector<whisper_token_data> tokens;

double no_speech_probs;

bool speaker_turn_next;
};

Expand Down Expand Up @@ -4525,7 +4527,6 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.split_on_word =*/ false,
/*.max_tokens =*/ 0,

/*.speed_up =*/ false,
Expand All @@ -4542,7 +4543,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.detect_language =*/ false,

/*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false,
/*.suppress_non_speech_tokens =*/ true,

/*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f,
Expand Down Expand Up @@ -4613,65 +4614,144 @@ static void whisper_exp_compute_token_level_timestamps(
float thold_pt,
float thold_ptsum);

//static inline bool should_split_on_word(const char * txt, bool split_on_word) {
// if (!split_on_word) return true;
//
// return txt[0] == ' ';
//}
// Structural UTF-8 check: returns true when `str` is made up only of complete
// byte sequences, i.e. every lead byte (110xxxxx, 1110xxxx, 11110xxx) is followed
// by the announced number of continuation bytes (10xxxxxx) and no continuation
// byte appears on its own. An empty string is accepted.
//
// Only the byte layout is inspected - code point ranges are not checked.
static bool whisper_utf8_is_valid(const std::string &str) {
    uint64_t pending = 0; // continuation bytes still expected for the current character

    for (size_t i = 0; i < str.size(); ++i) {
        const unsigned char byte = static_cast<unsigned char>(str[i]);

        if (pending > 0) {
            // inside a multi-byte character: only 10xxxxxx is acceptable
            if ((byte & 0xC0) != 0x80) {
                return false;
            }
            --pending;
            continue;
        }

        // start of a new character: classify the lead byte
        if ((byte & 0x80) == 0x00) {
            pending = 0;            // 0xxxxxxx - single byte
        } else if ((byte & 0xE0) == 0xC0) {
            pending = 1;            // 110xxxxx - two bytes
        } else if ((byte & 0xF0) == 0xE0) {
            pending = 2;            // 1110xxxx - three bytes
        } else if ((byte & 0xF8) == 0xF0) {
            pending = 3;            // 11110xxx - four bytes
        } else {
            return false;           // stray continuation byte or 11111xxx
        }
    }

    // a trailing, unfinished character makes the string invalid
    return pending == 0;
}

// C-string entry point for the UTF-8 check (this is the signature exported
// through whisper.h, so it must not be `static`).
//
// str: NUL-terminated byte string; may be null.
// Returns false for a null pointer, otherwise the result of the std::string
// overload applied to the bytes up to the terminating NUL.
bool whisper_utf8_is_valid(const char * str) {
    if (str == nullptr) {
        // constructing std::string from a null pointer is undefined behavior
        return false;
    }

    return whisper_utf8_is_valid(std::string(str));
}

// Cuts `str` into pieces, one piece per UTF-8 lead/single byte together with the
// continuation bytes (10xxxxxx) that follow it, up to the count announced by the
// lead byte. A byte that is not a continuation byte always starts a new piece,
// even if the previous piece is still short of continuation bytes.
//
// Each result entry pairs the piece with a flag: every piece emitted while
// scanning carries `true`, the final piece (emitted after the last byte) carries
// `false`.
static std::vector<whisper_pair<std::string, bool>> whisper_utf8_merge_and_split(const std::string &str) {
    std::vector<whisper_pair<std::string, bool>> pieces;
    std::string current;
    uint64_t remaining = 0; // continuation bytes still expected for `current`

    for (const unsigned char byte : str) {
        const bool is_continuation = (byte & 0xC0) == 0x80;

        if (remaining > 0 && is_continuation) {
            // extends the character that is being collected
            current += static_cast<char>(byte);
            --remaining;
            continue;
        }

        // this byte opens a new piece: work out how many continuation bytes it announces
        if ((byte & 0xE0) == 0xC0) {
            remaining = 1; // 2-byte character
        } else if ((byte & 0xF0) == 0xE0) {
            remaining = 2; // 3-byte character
        } else if ((byte & 0xF8) == 0xF0) {
            remaining = 3; // 4-byte character
        } else {
            remaining = 0; // 1-byte character or invalid lead byte
        }

        // hand over whatever was collected so far, then restart with this byte
        if (!current.empty()) {
            pieces.emplace_back(current, true);
        }
        current.assign(1, static_cast<char>(byte));
    }

    if (!current.empty()) {
        pieces.emplace_back(current, false);
    }

    return pieces;
}

static std::vector<whisper_segment> whisper_split_tokens_on_utf8(struct whisper_context & ctx, whisper_segment & segment, bool special) {
std::vector<whisper_segment> words;

std::string text;
std::vector<whisper_token_data> raw;
int64_t t0 = -1;
int64_t t1 = -1;

for (const auto & token : segment.tokens) {
if (special == false && token.id >= whisper_token_beg(&ctx)) {
continue;
}
if (t0 < 0) {t0 = token.t0;}
t1 = token.t1;
text += whisper_token_to_str(&ctx, token.id);
raw.push_back(token);

if (whisper_utf8_is_valid(text)) {
words.push_back({t0, t1, text, raw, segment.no_speech_probs, segment.speaker_turn_next});
t0 = -1;
t1 = -1;
raw.clear();
text = "";
}
}

return words;
}

// wrap the last segment into segments with max_len number of words
// returns the number of new segments
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) {
static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool special) {
const static std::set<std::string> unicode_language = {"zh", "ja", "th", "lo", "my", "yue"};
const static std::string punctuation = R"(!"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~)";

auto segment = state.result_all.back();

int res = 1;
int acc = 0;
std::vector<whisper_segment> words;

std::string text;
if (unicode_language.find(whisper_lang_str(ctx.state->lang_id)) != unicode_language.end()) {
// split on utf-8
words = whisper_split_tokens_on_utf8(ctx, segment, special);
} else {
// split on spaces and punctuation
auto subwords = whisper_split_tokens_on_utf8(ctx, segment, special);

// for (int i = 0; i < (int) segment.tokens.size(); i++) {
// const auto & token = segment.tokens[i];
// if (token.id >= whisper_token_eot(&ctx)) {
// continue;
// }
//
// const auto txt = whisper_token_to_str(&ctx, token.id);
// const int cur = strlen(txt);
//
// if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) {
// state.result_all.back().text = std::move(text);
// state.result_all.back().t1 = token.t0;
// state.result_all.back().tokens.resize(i);
// state.result_all.back().speaker_turn_next = false;
//
// state.result_all.push_back({});
// state.result_all.back().t0 = token.t0;
// state.result_all.back().t1 = segment.t1;
//
// // add tokens [i, end] to the new segment
// state.result_all.back().tokens.insert(
// state.result_all.back().tokens.end(),
// segment.tokens.begin() + i,
// segment.tokens.end());
//
// state.result_all.back().speaker_turn_next = segment.speaker_turn_next;
//
// acc = 0;
// text = "";
//
// segment = state.result_all.back();
// i = -1;
//
// res++;
// } else {
// acc += cur;
// text += txt;
// }
// }
//
// state.result_all.back().text = std::move(text);
for (auto & subword : subwords) {
if (subword.tokens[0].id >= whisper_token_beg(&ctx) || subword.text[0] == ' ' || punctuation.find(subword.text) != std::string::npos) {
words.push_back(subword);
} else {
words.back().t1 = subword.t1;
words.back().text += subword.text;
words.back().tokens.insert(words.back().tokens.end(), subword.tokens.begin(), subword.tokens.end());
}
}
}

return res;
state.result_all.pop_back();

if (max_len == 1) {
state.result_all.insert(state.result_all.end(), words.begin(), words.end());
return static_cast<int>(words.size());
} else {
int acc = 0;
int n_new = 0;
whisper_segment temp = {};

for (auto & word : words) {
if (acc == 0) {temp.t0 = word.t0;}
temp.t1 = word.t1;
temp.text += word.text;
temp.tokens.insert(temp.tokens.end(), word.tokens.begin(), word.tokens.end());
temp.speaker_turn_next = word.speaker_turn_next;
temp.no_speech_probs = word.no_speech_probs;

if (acc + 1 >= max_len) {
state.result_all.push_back(temp);
temp = {};
acc = 0;
n_new ++;
} else {
acc++;
}
}
return n_new;
}
}

// ref: https://github.com/openai/whisper/blob/ba3f3cd54b0e5b8ce1ab3de13e32122d0d5f98ab/whisper/decoding.py#L689-L693
Expand Down Expand Up @@ -5927,6 +6007,7 @@ int whisper_full_with_state(

const auto seek_delta = best_decoder.seek_delta;
const auto result_len = best_decoder.sequence.result_len;
const auto non_speech_probs = best_decoder.sequence.no_speech_probs;

const auto & tokens_cur = best_decoder.sequence.tokens;

Expand Down Expand Up @@ -5965,18 +6046,19 @@ int whisper_full_with_state(
auto text_callback = [&](int t1, int token_offset, int end) {
int n_new = 1;

result_all.push_back({ t0, t1, text, {} , speaker_turn_next });
result_all.push_back({ t0, t1, text, {}, non_speech_probs, speaker_turn_next });
for (int j = std::max(0, token_offset); j <= end; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}

if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(*ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
}

if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word);
}
if (params.max_len > 0) {
n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.print_special);
}

if (params.new_segment_callback) {
params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data);
}
Expand Down Expand Up @@ -6192,6 +6274,14 @@ int whisper_full_lang_id(struct whisper_context * ctx) {
return ctx->state->lang_id;
}

// Returns the no_speech_probs value stored with segment `i_segment` of `state`.
// `i_segment` is used to index state->result_all directly.
double whisper_full_get_segment_no_speech_probs_from_state(struct whisper_state * state, int i_segment) {
    const auto & segment = state->result_all[i_segment];
    return segment.no_speech_probs;
}

// Returns the no_speech_probs value stored with segment `i_segment` of the
// context's own state (ctx->state). `i_segment` is used to index result_all directly.
double whisper_full_get_segment_no_speech_probs(struct whisper_context * ctx, int i_segment) {
    const auto & segment = ctx->state->result_all[i_segment];
    return segment.no_speech_probs;
}

// Returns the t0 (start time) field of segment `i_segment` of `state`.
// `i_segment` is used to index state->result_all directly.
int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) {
    const auto & segment = state->result_all[i_segment];
    return segment.t0;
}
Expand Down
11 changes: 9 additions & 2 deletions whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,7 @@ extern "C" {
bool token_timestamps; // enable token-level timestamps
float thold_pt; // timestamp token probability threshold (~0.01)
float thold_ptsum; // timestamp token sum probability threshold (~0.01)
int max_len; // max segment length in characters
bool split_on_word; // split on word rather than on token (when used with max_len)
int max_len; // max segment length in characters (0 = no limit)
int max_tokens; // max tokens per segment (0 = no limit)

// [EXPERIMENTAL] speed-up techniques
Expand Down Expand Up @@ -570,6 +569,10 @@ extern "C" {
// Language id associated with the provided state
WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);

// Get the no speech probability of the specified segment
WHISPER_API double whisper_full_get_segment_no_speech_probs (struct whisper_context * ctx, int i_segment);
WHISPER_API double whisper_full_get_segment_no_speech_probs_from_state(struct whisper_state * state, int i_segment);

// Get the start and end time of the specified segment
WHISPER_API int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment);
WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
Expand Down Expand Up @@ -605,6 +608,10 @@ extern "C" {
WHISPER_API float whisper_full_get_token_p (struct whisper_context * ctx, int i_segment, int i_token);
WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);

// Check if the string is valid UTF-8
WHISPER_API bool whisper_utf8_is_valid(const char * str);


////////////////////////////////////////////////////////////////////////////

// Temporary helpers needed for exposing ggml interface
Expand Down

0 comments on commit 8a46034

Please sign in to comment.