Add timestamp logits processor for whisper #15853

Merged · 23 commits · May 17, 2023
Changes from 10 commits
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -428,7 +428,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>

-#### Inputs (5 - 10)
+#### Inputs (5 - 11)

<dl>
<dt><tt>input_ids</tt> : F</dt>
@@ -451,6 +451,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Mask of vocabulary for the first step. Words masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>timestamp_enable</tt> (optional) : tensor(bool)</dt>
> **@tianleiwu** (Contributor) — May 13, 2023:
> How about `enable_timestamp`? Could we use an attribute instead of an input?

> **@stevenlix** (Contributor, Author) — May 15, 2023:
> If we put `timestamp_enable` in the attributes, users would need to regenerate the whisper model whenever they want to switch timestamps on and off.
<dd>Enable timestamp processing. True means enabled. Shape is (1). Default is ``False``</dd>
</dl>

#### Outputs (1 - 3)
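As the author notes in the thread above, keeping this as an optional runtime input means the same exported model can serve both modes. A minimal sketch of that usage, assuming a whisper beam-search model exported with the graph inputs shown later in this PR (the model path and feed values below are placeholders, not part of this change):

```python
import numpy as np
import onnxruntime as ort

# Hypothetical exported model; path and feeds are illustrative placeholders.
sess = ort.InferenceSession("whisper_beamsearch.onnx")

feeds = {
    "input_features": np.zeros((1, 80, 3000), dtype=np.float32),  # dummy log-mel features
    "max_length": np.array([200], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([5], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "attention_mask": np.ones((1, 80, 3000), dtype=np.int32),
    "timestamp_enable": np.array([True]),  # flip to [False] to turn timestamps off per call
}
sequences = sess.run(["sequences"], feeds)[0]
```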
8 changes: 4 additions & 4 deletions docs/OperatorKernels.md
@@ -420,7 +420,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* timestamp_enable:**tensor(bool)**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -788,7 +788,7 @@ Do not modify directly.*
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* timestamp_enable:**tensor(bool)**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
@@ -1102,7 +1102,8 @@ Do not modify directly.*
|||11+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(float), tensor(float16)|
|||10+|**T** = tensor(float), tensor(float16)|
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
+|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
+|||10+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|Round|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(float), tensor(float16)|
|STFT|*in* signal:**T1**<br> *in* frame_step:**T2**<br> *in* window:**T1**<br> *in* frame_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(int32), tensor(int64)|
|ScaledTanh|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1192,7 +1193,6 @@ Do not modify directly.*
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -86,6 +86,13 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
repetition_penalty = 1.0f;
}
ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);

auto* timestamp_enable_tensor = context->Input<Tensor>(10);
if (timestamp_enable_tensor) {
timestamp_enable = static_cast<bool>(*timestamp_enable_tensor->Data<bool>());
} else {
timestamp_enable = false;
}
}

void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
@@ -146,6 +146,7 @@ struct IGenerationParameters {
int num_return_sequences;
float length_penalty;
float repetition_penalty;
bool timestamp_enable;
int batch_size; // deduce from first dimension of input_ids
int sequence_length; // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5

122 changes: 122 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -238,6 +238,128 @@ void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
#endif
}

template <typename T>
TimestampLogitsProcessor<T>::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
: eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}

template <typename T>
void TimestampLogitsProcessor<T>::Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) {
const int beg_token_id_ = eos_token_id_ + 107;
const int not_token_id_ = eos_token_id_ + 106;
const int solm_token_id_ = eos_token_id_ + 105;
const int sot_token_id_ = eos_token_id_ + 1;
constexpr int translate_token_id_ = 50358;
> **@tianleiwu** (Contributor) — May 13, 2023:
> These constant configurations (the special token IDs) should be read from attributes.

> **@stevenlix** (Contributor, Author) — May 14, 2023:
> eos_token_id_ is from an attribute. The tokens listed here have a constant offset from eos_token_id_, so they may not need to be provided explicitly in the attributes.
constexpr int transcribe_token_id_ = 50359;
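For reference, these offsets line up with the multilingual whisper vocabulary; a quick sketch of the arithmetic (the eos value 50257 is an assumption about the tokenizer, not something stated in this diff):

```python
eos_token_id = 50257                 # <|endoftext|> in multilingual whisper (assumed)
sot_token_id = eos_token_id + 1      # 50258 -> <|startoftranscript|>
solm_token_id = eos_token_id + 105   # 50362
not_token_id = eos_token_id + 106    # 50363 -> <|notimestamps|>
beg_token_id = eos_token_id + 107    # 50364 -> <|0.00|>, the first timestamp token
```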

const int batch_beam_size = next_token_scores.batch_beam_size;
const int vocab_size = next_token_scores.vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
gsl::span<const int32_t> sequence = sequences->GetSequence(i);
const size_t seq_length = sequence.size();

// Find first timestamp
size_t sample_begin = 0;
for (size_t j = 0; j < seq_length; j++) {
sample_begin++;
if (sequence[j] >= beg_token_id_) {
break;
}
}

// Suppress tokens
for (int j = 0; j < vocab_size; j++) {
// Suppress notimestamps and solm tokens
if (j == not_token_id_ || j == solm_token_id_) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}

// Suppress sot, translate and transcribe tokens
if (seq_length > sample_begin) {
if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}
}

// Timestamps should appear in pairs, except possibly the first one
const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;
if (last_was_timestamp) {
if (penultimate_was_timestamp) {
// If timestamps appear in a pair, or it's the first timestamp, no more timestamps are generated
for (int j = beg_token_id_; j < vocab_size; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
} else {
// If the last timestamp is unpaired, generate a timestamp next
for (int j = 0; j < eos_token_id_; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}
}

// Find timestamp tokens
std::vector<int32_t> timestamps;
for (const auto& word_id : sequence) {
if (word_id >= beg_token_id_) {
timestamps.push_back(word_id);
}
}

// Timestamps will not decrease
const size_t timestamps_len = timestamps.size();
if (timestamps_len > 0) {
int timestamp_last = 0;
if (last_was_timestamp && !penultimate_was_timestamp) {
// For single timestamp at the end, next timestamp must not be smaller
timestamp_last = timestamps.back();
} else {
// For paired timestamp at the end, next timestamp must be greater
timestamp_last = timestamps.back() + 1;
}

for (int j = beg_token_id_; j < timestamp_last; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}

if (seq_length == sample_begin) {
const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
for (int j = last_allowed + 1; j < vocab_size; j++) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}

// Calculate logsumexp over the timestamp tokens
float timestamp_logprob = std::numeric_limits<T>::lowest();
{
float logsumexp = 0.0f;
const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
for (int j = beg_token_id_; j < vocab_size; ++j) {
if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
logsumexp += expf(beam_token_scores[j] - logprob_max);
}
}
if (logsumexp > 0.0f) {
timestamp_logprob = logf(logsumexp) + logprob_max;
}
}

const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
if (timestamp_logprob > max_text_token_logprob) {
for (int j = 0; j < beg_token_id_; ++j) {
beam_token_scores[j] = std::numeric_limits<T>::lowest();
}
}
}

#ifdef DEBUG_GENERATION
DumpScores("TimestampLogitsProcessor", next_token_scores);
#endif
}

void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
}
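The final block of TimestampLogitsProcessor<T>::Process above compares the total probability mass on timestamp tokens against the best text token and, when the timestamps win, suppresses all text tokens so a timestamp must be sampled. A small numeric sketch of that rule (toy values; note the C++ code subtracts the max score before exponentiating for numerical stability):

```python
import numpy as np

# Toy vocabulary: 5 text tokens followed by 3 timestamp tokens.
scores = np.array([2.0, 1.5, 0.5, 0.0, -1.0,  # text tokens
                   1.8, 1.2, 0.3])            # timestamp tokens
beg = 5  # index of the first timestamp token in this toy vocabulary

timestamp_logprob = np.log(np.exp(scores[beg:]).sum())  # logsumexp over timestamps
max_text_logprob = scores[:beg].max()

# logsumexp([1.8, 1.2, 0.3]) ~= 2.37 > 2.0, so the text tokens get suppressed
if timestamp_logprob > max_text_logprob:
    scores[:beg] = np.finfo(scores.dtype).min
```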
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -138,6 +138,19 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
float presence_penalty_;
};

template <typename T>
class TimestampLogitsProcessor : public ILogitsProcessor<T> {
public:
TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index);

void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;

private:
int eos_token_id_;
int max_initial_timestamp_index_;
};

class LogitsProcessorList : public ILogitsProcessorList {
public:
LogitsProcessorList() = default;
@@ -193,6 +206,12 @@ class LogitsProcessorList : public ILogitsProcessorList {
processor_list_.push_back(presence_penalty_processor_.get());
}

if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.timestamp_enable) {
constexpr int max_initial_timestamp_index = 50;
timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<float>>(parameters.eos_token_id, max_initial_timestamp_index);
processor_list_.push_back(timestamp_processor_.get());
}

batch_beam_size_ = parameters.BatchBeamSize();
vocab_size_ = parameters.vocab_size;
}
@@ -208,6 +227,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
std::unique_ptr<MinLengthLogitsProcessor<float>> min_length_processor_;
std::unique_ptr<TemperatureLogitsProcessor<float>> temperature_processor_;
std::unique_ptr<PresencePenaltyLogitsProcessor<float>> presence_penalty_processor_;
std::unique_ptr<TimestampLogitsProcessor<float>> timestamp_processor_;
};

} // namespace transformers
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1095,6 +1095,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
.Input(7, "vocab_mask", "Mask of vocabulary. Words masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional)
.Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
.Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
.Input(10, "timestamp_enable", "Enable timestamp processing. True means enabled. Shape is (1). Default is ``False``", "tensor(bool)", OpSchema::Optional)
> **Contributor**:
> Maybe we could export timestamp processing to a subgraph, then call the subgraph for logits processing. That way, users could add their own logits processing.

> **@stevenlix** (Contributor, Author) — May 14, 2023:
> It's an interesting idea; will think about it. Probably will add it in a separate PR.

> **Contributor**:
> Another solution is to make the interface more general, e.g. a logits_processor input of tensor(int): value 1 indicates timestamp processing, and a model with its own processing could use 2, etc., for a new processor.

> **Contributor (Author)**:
> Done
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
.Output(2, "scores",
@@ -22,7 +22,7 @@
logger = logging.getLogger("")


-def parse_arguments():
+def parse_arguments(argv=None):
parser = argparse.ArgumentParser()

pretrained_models = PRETRAINED_WHISPER_MODELS
@@ -166,7 +166,7 @@ def parse_arguments():
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)

-    args = parser.parse_args()
+    args = parser.parse_args(argv)

return args

@@ -288,8 +288,8 @@ def export_onnx_models(
return output_paths


-def main():
-    args = parse_arguments()
+def main(argv=None):
+    args = parse_arguments(argv)

setup_logger(args.verbose)

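Threading argv through parse_arguments and main makes the converter callable from Python as well as from the shell; a hypothetical sketch, assuming this is the whisper convert_to_onnx.py script (the flag names are assumed from its argparse setup and may differ):

```python
from convert_to_onnx import main

# Equivalent to: python convert_to_onnx.py -m openai/whisper-tiny --output ./onnx_models
main(["-m", "openai/whisper-tiny", "--output", "./onnx_models"])
```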
@@ -40,6 +40,7 @@ def chain_model(args):
"",
"",
"attention_mask",
"timestamp_enable",
]
beam_outputs = ["sequences"]

@@ -69,6 +70,7 @@
attention_mask = helper.make_tensor_value_info(
"attention_mask", TensorProto.INT32, ["batch_size", "feature_size", "sequence_length"]
)
timestamp_enable = helper.make_tensor_value_info("timestamp_enable", TensorProto.BOOL, [True])

graph_inputs = [
input_features,
@@ -79,6 +81,7 @@
length_penalty,
repetition_penalty,
attention_mask,
timestamp_enable,
]

# graph outputs