Add timestamp logits processor for whisper #15853
Changes from 10 commits
@@ -238,6 +238,128 @@ void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
#endif
}
template <typename T>
TimestampLogitsProcessor<T>::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
    : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}

template <typename T>
void TimestampLogitsProcessor<T>::Process(const ISequences* sequences,
                                          NextTokenScores<T>& next_token_scores) {
  const int beg_token_id_ = eos_token_id_ + 107;
  const int not_token_id_ = eos_token_id_ + 106;
  const int solm_token_id_ = eos_token_id_ + 105;
  const int sot_token_id_ = eos_token_id_ + 1;
  constexpr int translate_token_id_ = 50358;
[Reviewer] These constant configurations (the special token IDs) should be read from attributes.
[Author] eos_token_id_ is from an attribute. The tokens listed here have a constant offset from eos_token_id_ and may not need to be provided explicitly in the attributes.
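A minimal sketch of that point (assuming the multilingual Whisper vocabulary, where eos_token_id is 50257; these concrete IDs are not stated in the PR): the offsets above then resolve to sot = 50258 (<|startoftranscript|>), not = 50363 (<|notimestamps|>), and beg = 50364 (the first timestamp token <|0.00|>). The two hard-coded IDs (translate and transcribe) also sit at fixed offsets, so they could be derived the same way:

  // translate and transcribe are at eos + 101 and eos + 102 for the
  // multilingual vocabulary, so deriving them instead of hard-coding
  // them would also be possible (function-scope static_assert is legal).
  static_assert(50257 + 101 == 50358 && 50257 + 102 == 50359,
                "translate/transcribe sit at fixed offsets from eos");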
  constexpr int transcribe_token_id_ = 50359;

  const int batch_beam_size = next_token_scores.batch_beam_size;
  const int vocab_size = next_token_scores.vocab_size;
  for (int i = 0; i < batch_beam_size; i++) {
    gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
    gsl::span<const int32_t> sequence = sequences->GetSequence(i);
    const size_t seq_length = sequence.size();

    // Find the first position past the initial timestamp token
    size_t sample_begin = 0;
    for (size_t j = 0; j < seq_length; j++) {
      sample_begin++;
      if (sequence[j] >= beg_token_id_) {
        break;
      }
    }

    // Suppress tokens
    for (int j = 0; j < vocab_size; j++) {
      // Suppress the notimestamps and solm tokens
      if (j == not_token_id_ || j == solm_token_id_) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }

      // Suppress the sot, translate, and transcribe tokens once decoding has started
      if (seq_length > sample_begin) {
        if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      }
    }

    // Timestamps should appear in pairs, except for the first one
    const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
    const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;

    if (last_was_timestamp) {
      if (penultimate_was_timestamp) {
        // If timestamps show up in a pair, or this is the first timestamp, no more timestamps are generated
        for (int j = beg_token_id_; j < vocab_size; j++) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      } else {
        // If the last timestamp is unpaired, the next token must be a timestamp (or EOS)
        for (int j = 0; j < eos_token_id_; j++) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      }
    }
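    // Example (illustrative sketch, not part of the diff): with
    // beg_token_id_ = 50364, a sequence ending [..., 50364, 50366] closes a
    // timestamp pair, so all timestamp tokens are masked and text must
    // follow; a sequence ending [..., 50366] has an open (unpaired)
    // timestamp, so every text token is masked and only a timestamp or EOS
    // may follow.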

    // Collect the timestamp tokens generated so far
    std::vector<int32_t> timestamps;
    for (const auto& word_id : sequence) {
      if (word_id >= beg_token_id_) {
        timestamps.push_back(word_id);
      }
    }

    // Timestamps must not decrease
    const size_t timestamps_len = timestamps.size();
    if (timestamps_len > 0) {
      int timestamp_last = 0;
      if (last_was_timestamp && !penultimate_was_timestamp) {
        // For a single timestamp at the end, the next timestamp must not be smaller
        timestamp_last = timestamps.back();
      } else {
        // For a paired timestamp at the end, the next timestamp must be greater
        timestamp_last = timestamps.back() + 1;
      }

      for (int j = beg_token_id_; j < timestamp_last; j++) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }
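    // Example (illustrative sketch, not part of the diff): if the last
    // timestamp seen is 50378 and it is unpaired, tokens 50364..50377 are
    // masked and the next timestamp must be >= 50378; after a closed pair,
    // 50378 itself is also masked, so the next timestamp must be strictly
    // greater.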

    // At the start of decoding, cap the first timestamp at max_initial_timestamp_index_
    if (seq_length == sample_begin) {
      const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
      for (int j = last_allowed + 1; j < vocab_size; j++) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }

    // Calculate logsumexp over the timestamp tokens
    float timestamp_logprob = std::numeric_limits<T>::lowest();
    {
      float logsumexp = 0.0f;
      const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
      for (int j = beg_token_id_; j < vocab_size; ++j) {
        if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
          logsumexp += expf(beam_token_scores[j] - logprob_max);
        }
      }
      if (logsumexp > 0.0f) {
        timestamp_logprob = logf(logsumexp) + logprob_max;
      }
    }

    // If the total probability mass on timestamps exceeds that of the most likely text token, sample a timestamp
    const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
    if (timestamp_logprob > max_text_token_logprob) {
      for (int j = 0; j < beg_token_id_; ++j) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }
  }

#ifdef DEBUG_GENERATION
  DumpScores("TimestampLogitsProcessor", next_token_scores);
#endif
}

void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
  LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
}
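The comparison at the end of Process implements the Whisper decoding rule that when the total probability mass assigned to timestamp tokens exceeds the probability of the most likely text token, text tokens are suppressed so that a timestamp is sampled. A standalone sketch with toy numbers (not from the PR) of the max-shifted, numerically stable logsumexp used for that total:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Toy logits for the timestamp region of the vocabulary
  std::vector<float> timestamp_logits{-1.2f, -0.5f, -3.0f};
  // Subtract the max before exponentiating so exp cannot overflow
  const float logprob_max = *std::max_element(timestamp_logits.begin(), timestamp_logits.end());
  float sum = 0.0f;
  for (float v : timestamp_logits) {
    sum += std::exp(v - logprob_max);
  }
  const float timestamp_logprob = std::log(sum) + logprob_max;  // log(sum_j exp(logit_j))
  std::printf("total timestamp logprob = %f\n", timestamp_logprob);
  return 0;
}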
@@ -1095,6 +1095,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
        .Input(7, "vocab_mask", "Mask of vocabulary. Words that are masked with 0 are not allowed to be generated, and 1 means allowed. Shape is (vocab_size)", "M", OpSchema::Optional)
        .Input(8, "prefix_vocab_mask", "Mask of vocabulary for the first step. Words that are masked with 0 are not allowed to be generated, and 1 means allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
        .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
        .Input(10, "timestamp_enable", "Enable timestamp processing. True means enabled. Shape is (1). Default is ``False``", "tensor(bool)", OpSchema::Optional)
[Reviewer] Maybe we could export the timestamp processing to a subgraph, then call the subgraph for logits processing. That way, users could add their own logits processing.
[Author] It's an interesting idea, will think about it. Probably will add it in a separate PR.
[Reviewer] Another solution is to change the interface to something more general, like a logits_processor input of tensor(int). Then value 1 can indicate timestamp processing, and if another model has its own processing, it could use 2, etc. for a new processor.
[Author] Done
        .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
        .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
        .Output(2, "scores",
[Reviewer] How about enable_timestamp? Could we use an attribute instead of an input?
[Author] If we put timestamp_enable in the attributes, users would need to regenerate the whisper model whenever they want to switch timestamps on and off.
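A hypothetical usage sketch of that point (the model path is a placeholder and the remaining BeamSearch inputs are elided; this assumes the tensor(bool) interface shown in this revision of the diff): because the flag is a runtime input rather than an attribute, one exported model serves both modes.

#include <onnxruntime_cxx_api.h>
#include <array>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "whisper");
  // Placeholder path: any Whisper model exported with the BeamSearch op
  Ort::Session session(env, "whisper_beamsearch.onnx", Ort::SessionOptions{});

  // Toggle timestamp processing per run, without regenerating the model
  bool enable[1] = {true};
  std::array<int64_t, 1> shape{1};
  Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  Ort::Value timestamp_enable =
      Ort::Value::CreateTensor<bool>(mem, enable, 1, shape.data(), shape.size());
  (void)timestamp_enable;  // silence unused-variable warning in this sketch

  // ... create the other BeamSearch inputs, then pass "timestamp_enable"
  // alongside them to session.Run(...).
  return 0;
}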