Commit 270c09a

Add timestamp logits processor for whisper (#15853)

Enable timestamp estimation and logits processing for the Whisper model.

stevenlix authored May 17, 2023
1 parent f62f722
Showing 10 changed files with 254 additions and 7 deletions.
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -428,7 +428,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>

-#### Inputs (5 - 11)
+#### Inputs (5 - 12)

<dl>
<dt><tt>input_ids</tt> : F</dt>
@@ -453,6 +453,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>decoder_input_ids</tt> (optional) : I</dt>
<dd>The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)</dd>
<dt><tt>logits_processor</tt> (optional) : I</dt>
<dd>Specific logits processor for different types of beam search models. Default value 0 means no specific logits processor. Accepts values >= 0. Shape is (1)</dd>
</dl>

#### Outputs (1 - 3)
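In practice the new input is a single int32 value of shape (1). A minimal sketch of supplying it at session-run time (the tensor name and value follow the Whisper unit test added in this commit; the other BeamSearch inputs are elided):

```python
import numpy as np

ort_inputs = {}  # the remaining BeamSearch inputs would go here
# 1 selects the Whisper timestamp processor; 0 (or omitting the optional
# input entirely) means no model-specific logits processing.
ort_inputs["logits_processor"] = np.array([1], dtype=np.int32)  # shape (1)
```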
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -422,7 +422,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -790,7 +790,7 @@ Do not modify directly.*
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -101,6 +101,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
    repetition_penalty = 1.0f;
  }
  ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);

  auto* logits_processor_tensor = context->Input<Tensor>(11);
  logits_processor = logits_processor_tensor ? static_cast<int>(*logits_processor_tensor->Data<int32_t>()) : 0;
  ORT_ENFORCE(logits_processor >= 0,
              "logits_processor shall be a non-negative integer, got ", logits_processor);
}

void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -126,6 +126,8 @@ struct IGenerationParameters {
  static constexpr int kModelTypeT5 = 1;
  static constexpr int kModelTypeWhisper = 2;

  static constexpr int kLogitsProcessorTypeWhisper = 1;

  // Parameters from node attributes
  int model_type;  // 0 for GPT-2; 1 for encoder-decoder like T5; 2 for float inputs like Whisper
  int eos_token_id;
@@ -143,6 +145,7 @@
  float repetition_penalty;
  int batch_size;       // deduce from first dimension of input_ids
  int sequence_length;  // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5
  int logits_processor;

  gsl::span<const int32_t> vocab_mask;
  gsl::span<const int32_t> prefix_vocab_mask;
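For quick reference, a Python mirror of the constants involved (values copied from the struct above; the names on the left are illustrative, not part of the ONNX Runtime API):

```python
# 0 for GPT-2; 1 for encoder-decoder like T5; 2 for Whisper (see model_type above)
MODEL_TYPE_GPT2, MODEL_TYPE_T5, MODEL_TYPE_WHISPER = 0, 1, 2
# Value to pass in the logits_processor input to enable Whisper timestamp processing
LOGITS_PROCESSOR_TYPE_WHISPER = 1
```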
122 changes: 122 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -238,6 +238,128 @@ void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
#endif
}

template <typename T>
TimestampLogitsProcessor<T>::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
    : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}

template <typename T>
void TimestampLogitsProcessor<T>::Process(const ISequences* sequences,
                                          NextTokenScores<T>& next_token_scores) {
  const int beg_token_id_ = eos_token_id_ + 107;
  const int not_token_id_ = eos_token_id_ + 106;
  const int solm_token_id_ = eos_token_id_ + 105;
  const int sot_token_id_ = eos_token_id_ + 1;
  constexpr int translate_token_id_ = 50358;
  constexpr int transcribe_token_id_ = 50359;

  const int batch_beam_size = next_token_scores.batch_beam_size;
  const int vocab_size = next_token_scores.vocab_size;
  for (int i = 0; i < batch_beam_size; i++) {
    gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
    gsl::span<const int32_t> sequence = sequences->GetSequence(i);
    const size_t seq_length = sequence.size();

    // Find first timestamp
    size_t sample_begin = 0;
    for (size_t j = 0; j < seq_length; j++) {
      sample_begin++;
      if (sequence[j] >= beg_token_id_) {
        break;
      }
    }

    // Suppress tokens
    for (int j = 0; j < vocab_size; j++) {
      // Suppress notimestamps and solm tokens
      if (j == not_token_id_ || j == solm_token_id_) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }

      // Suppress sot, translate and transcribe tokens
      if (seq_length > sample_begin) {
        if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      }
    }

    // Timestamps should appear in pairs, except for the first one
    const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
    const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;
    if (last_was_timestamp) {
      if (penultimate_was_timestamp) {
        // If timestamps appear in a pair, or this is the first timestamp, no further timestamp is generated
        for (int j = beg_token_id_; j < vocab_size; j++) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      } else {
        // If the last timestamp is unpaired, force the next token to be a timestamp
        for (int j = 0; j < eos_token_id_; j++) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      }
    }

    // Find timestamp tokens
    std::vector<int32_t> timestamps;
    for (const auto& word_id : sequence) {
      if (word_id >= beg_token_id_) {
        timestamps.push_back(word_id);
      }
    }

    // Timestamps must not decrease
    const size_t timestamps_len = timestamps.size();
    if (timestamps_len > 0) {
      int timestamp_last = 0;
      if (last_was_timestamp && !penultimate_was_timestamp) {
        // For a single timestamp at the end, the next timestamp must not be smaller
        timestamp_last = timestamps.back();
      } else {
        // For a paired timestamp at the end, the next timestamp must be greater
        timestamp_last = timestamps.back() + 1;
      }

      for (int j = beg_token_id_; j < timestamp_last; j++) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }

    if (seq_length == sample_begin) {
      const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
      for (int j = last_allowed + 1; j < vocab_size; j++) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }

    // Calculate logsumexp over timestamp tokens
    float timestamp_logprob = std::numeric_limits<T>::lowest();
    {
      float logsumexp = 0.0f;
      const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
      for (int j = beg_token_id_; j < vocab_size; ++j) {
        if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
          logsumexp += expf(beam_token_scores[j] - logprob_max);
        }
      }
      if (logsumexp > 0.0f) {
        timestamp_logprob = logf(logsumexp) + logprob_max;
      }
    }

    const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
    if (timestamp_logprob > max_text_token_logprob) {
      for (int j = 0; j < beg_token_id_; ++j) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }
  }

#ifdef DEBUG_GENERATION
  DumpScores("TimestampLogitsProcessor", next_token_scores);
#endif
}

void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
  LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
}
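For readers who would rather not trace the C++ loop above, here is a rough single-beam NumPy transcription of the same masking rules. The token offsets (eos+107 and friends) and the hardcoded translate/transcribe ids are taken from the code above and correspond to the multilingual Whisper-tiny vocabulary; treat this as an illustrative sketch, not the shipped implementation:

```python
import numpy as np

def apply_timestamp_rules(scores, sequence, eos_token_id=50256, max_initial_timestamp_index=50):
    """Single-beam mirror of TimestampLogitsProcessor::Process (illustration only)."""
    beg = eos_token_id + 107      # first timestamp token, <|0.00|>
    not_ts = eos_token_id + 106   # <|notimestamps|>
    solm = eos_token_id + 105     # start-of-LM token
    sot = eos_token_id + 1        # <|startoftranscript|>
    translate, transcribe = 50358, 50359  # hardcoded ids, as in the C++ above
    neg_inf = np.finfo(scores.dtype).min

    # sample_begin: position just past the first timestamp/prompt token
    sample_begin = 0
    for tok in sequence:
        sample_begin += 1
        if tok >= beg:
            break

    scores[[not_ts, solm]] = neg_inf
    if len(sequence) > sample_begin:
        scores[[sot, translate, transcribe]] = neg_inf

    # Timestamps come in pairs, except possibly the first one
    last_ts = len(sequence) > 0 and sequence[-1] >= beg
    penult_ts = len(sequence) <= sample_begin or sequence[-2] >= beg
    if last_ts:
        if penult_ts:
            scores[beg:] = neg_inf           # pair closed: no third timestamp now
        else:
            scores[:eos_token_id] = neg_inf  # pair open: force a closing timestamp

    # Timestamps must be non-decreasing across the sequence
    timestamps = [t for t in sequence if t >= beg]
    if timestamps:
        lowest_allowed = timestamps[-1] if (last_ts and not penult_ts) else timestamps[-1] + 1
        scores[beg:lowest_allowed] = neg_inf

    # Bound the very first timestamp
    if len(sequence) == sample_begin:
        scores[beg + max_initial_timestamp_index + 1:] = neg_inf

    # If the total probability mass on timestamps outweighs every text token,
    # commit to sampling a timestamp
    if np.logaddexp.reduce(scores[beg:]) > scores[:beg].max():
        scores[:beg] = neg_inf
    return scores
```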
21 changes: 21 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -138,6 +138,19 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
  float presence_penalty_;
};

template <typename T>
class TimestampLogitsProcessor : public ILogitsProcessor<T> {
 public:
  TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index);

  void Process(const ISequences* sequences,
               NextTokenScores<T>& next_token_scores) override;

 private:
  int eos_token_id_;
  int max_initial_timestamp_index_;
};

class LogitsProcessorList : public ILogitsProcessorList {
 public:
  LogitsProcessorList() = default;
@@ -193,6 +206,13 @@ class LogitsProcessorList : public ILogitsProcessorList {
      processor_list_.push_back(presence_penalty_processor_.get());
    }

    // Add timestamp processor for whisper model
    if (parameters.model_type == IGenerationParameters::kModelTypeWhisper &&
        parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) {
      constexpr int max_initial_timestamp_index = 50;
      timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<float>>(parameters.eos_token_id, max_initial_timestamp_index);
      processor_list_.push_back(timestamp_processor_.get());
    }

    batch_beam_size_ = parameters.BatchBeamSize();
    vocab_size_ = parameters.vocab_size;
  }
@@ -208,6 +228,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
  std::unique_ptr<MinLengthLogitsProcessor<float>> min_length_processor_;
  std::unique_ptr<TemperatureLogitsProcessor<float>> temperature_processor_;
  std::unique_ptr<PresencePenaltyLogitsProcessor<float>> presence_penalty_processor_;
  std::unique_ptr<TimestampLogitsProcessor<float>> timestamp_processor_;
};

} // namespace transformers
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1096,6 +1096,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
    .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
    .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
    .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
    .Input(11, "logits_processor", "Specific logits processor for different types of beam search models. Default value 0 means no specific logits processor. Accepts values >= 0. Shape is (1)", "I", OpSchema::Optional)
    .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
    .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
    .Output(2, "scores",
17 changes: 13 additions & 4 deletions onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -22,7 +22,7 @@
logger = logging.getLogger("")


-def parse_arguments():
+def parse_arguments(argv=None):
    parser = argparse.ArgumentParser()

    pretrained_models = PRETRAINED_WHISPER_MODELS
@@ -98,6 +98,15 @@ def parse_arguments():
    )
    parser.set_defaults(use_forced_decoder_ids=False)

    parser.add_argument(
        "-l",
        "--use_logits_processor",
        required=False,
        action="store_true",
        help="Use logits_processor as an extra graph input to enable specific logits processing",
    )
    parser.set_defaults(use_logits_processor=False)

    parser.add_argument(
        "-w",
        "--overwrite",
@@ -176,7 +185,7 @@ def parse_arguments():
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)

args = parser.parse_args()
args = parser.parse_args(argv)

return args

@@ -298,8 +307,8 @@ def export_onnx_models(
    return output_paths


-def main():
-    args = parse_arguments()
+def main(argv=None):
+    args = parse_arguments(argv)

    setup_logger(args.verbose)

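Threading argv through parse_arguments and main lets callers drive the exporter in-process rather than through the command line; the new unit test below relies on this. A minimal sketch, with the flag string copied from that test:

```python
from onnxruntime.transformers.models.whisper.convert_to_onnx import main

# -l adds the logits_processor graph input; the other flags are the export
# settings the timestamp unit test uses.
main("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e".split())
```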
9 changes: 9 additions & 0 deletions onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -33,6 +33,11 @@ def chain_model(args):
    ]
    if args.use_forced_decoder_ids:
        beam_inputs.append("decoder_input_ids")
    else:
        beam_inputs.append("")

    if args.use_logits_processor:
        beam_inputs.append("logits_processor")
    beam_outputs = ["sequences"]

    node = helper.make_node("BeamSearch", inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode")
@@ -74,6 +79,10 @@
        )
        graph_inputs.append(decoder_input_ids)

    if args.use_logits_processor:
        logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
        graph_inputs.append(logits_processor)

    # graph outputs
    sequences = helper.make_tensor_value_info(
        "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
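One subtlety in chain_model: ONNX identifies a node's optional inputs by position, so a skipped input must be held open with an empty string — that is why the code appends "" when decoder_input_ids is absent, keeping logits_processor at input index 11. A hedged sketch of the resulting node (this input list is illustrative, not the exact one chain_model builds):

```python
from onnx import helper

beam_inputs = [
    "input_features", "max_length", "min_length", "num_beams",
    "num_return_sequences", "length_penalty", "repetition_penalty",
    "", "", "",          # vocab_mask, prefix_vocab_mask, attention_mask skipped
    "",                  # decoder_input_ids skipped when use_forced_decoder_ids is off
    "logits_processor",  # lands at input index 11, matching the schema
]
node = helper.make_node(
    "BeamSearch", inputs=beam_inputs, outputs=["sequences"],
    name="BeamSearch_zcode", domain="com.microsoft",
)
```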
75 changes: 75 additions & 0 deletions onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
@@ -0,0 +1,75 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import unittest

import numpy as np
import pytest
import torch

from onnxruntime import InferenceSession, SessionOptions


class TestTimestampProcessor(unittest.TestCase):
    def generate_model(self, arguments: str):
        from onnxruntime.transformers.models.whisper.convert_to_onnx import main as whisper_to_onnx

        whisper_to_onnx(arguments.split())

    def generate_dataset(self):
        from datasets import load_dataset
        from transformers import AutoProcessor

        processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        input_features = inputs.input_features
        return [input_features, processor]

    def run_timestamp(self, provider: str):
        self.generate_model("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e")
        [input_features, processor] = self.generate_dataset()
        model_path = "./onnx_models/openai/whisper-tiny_beamsearch.onnx"
        sess_options = SessionOptions()
        sess_options.log_severity_level = 4
        sess = InferenceSession(model_path, sess_options, providers=[provider])
        input_data = input_features.repeat(1, 1, 1)
        ort_inputs = {
            "input_features": np.float32(input_data.cpu().numpy()),
            "max_length": np.array([128], dtype=np.int32),
            "min_length": np.array([0], dtype=np.int32),
            "num_beams": np.array([1], dtype=np.int32),
            "num_return_sequences": np.array([1], dtype=np.int32),
            "length_penalty": np.array([1.0], dtype=np.float32),
            "repetition_penalty": np.array([1.0], dtype=np.float32),
            "logits_processor": np.array([1], dtype=np.int32),
        }
        ort_out = sess.run(None, ort_inputs)
        ort_out_tensor = torch.from_numpy(ort_out[0])
        ort_transcription = processor.batch_decode(
            ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True
        )
        expected_transcription = [
            {
                "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
                "offsets": [
                    {
                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
                        "timestamp": (0.0, 5.44),
                    }
                ],
            }
        ]
        self.assertEqual(ort_transcription, expected_transcription)

    @pytest.mark.slow
    def test_timestamp_cpu(self):
        provider = "CPUExecutionProvider"
        self.run_timestamp(provider)


if __name__ == "__main__":
    unittest.main()
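Because batch_decode is called with output_offsets=True, each decoded entry carries per-segment timestamps in the structure shown in expected_transcription; a small follow-on sketch for printing them:

```python
for segment in ort_transcription[0]["offsets"]:
    start, end = segment["timestamp"]
    print(f"[{start:.2f}s -> {end:.2f}s]{segment['text']}")
```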
