From 270c09a37f36e13ee986144b2385199d23e030a0 Mon Sep 17 00:00:00 2001
From: stevenlix <38092805+stevenlix@users.noreply.github.com>
Date: Tue, 16 May 2023 21:40:00 -0700
Subject: [PATCH] Add timestamp logits processor for whisper (#15853)
Enable timestamp estimation and logits processing for the Whisper model. BeamSearch gains an optional logits_processor input (0 = no extra processing, 1 = Whisper timestamp rules), and a new TimestampLogitsProcessor enforces paired, non-decreasing timestamps with a cap on the initial timestamp.
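
A minimal end-to-end sketch of how the new input is used, mirroring the added
test_whisper_timestamp_processor.py. The model name, output path, and the way
input_features is produced are illustrative assumptions, not part of this change:

    # Export whisper-tiny with the new -l/--use_logits_processor flag, e.g.
    #   python -m onnxruntime.transformers.models.whisper.convert_to_onnx \
    #       -m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e
    # then run the chained beamsearch model and select the Whisper timestamp
    # processor via the new optional input.
    import numpy as np
    from onnxruntime import InferenceSession

    # input_features: float32 log-mel features, e.g. from transformers'
    # WhisperProcessor (assumed to be prepared by the caller).
    sess = InferenceSession("onnx_models/openai/whisper-tiny_beamsearch.onnx",
                            providers=["CPUExecutionProvider"])
    sequences = sess.run(None, {
        "input_features": input_features,
        "max_length": np.array([128], dtype=np.int32),
        "min_length": np.array([0], dtype=np.int32),
        "num_beams": np.array([1], dtype=np.int32),
        "num_return_sequences": np.array([1], dtype=np.int32),
        "length_penalty": np.array([1.0], dtype=np.float32),
        "repetition_penalty": np.array([1.0], dtype=np.float32),
        "logits_processor": np.array([1], dtype=np.int32),  # 1 = Whisper timestamp processor, 0 = none
    })[0]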
---
docs/ContribOperators.md | 4 +-
docs/OperatorKernels.md | 4 +-
.../transformers/beam_search_parameters.cc | 5 +
.../cpu/transformers/generation_shared.h | 3 +
.../cpu/transformers/logits_processor.cc | 122 ++++++++++++++++++
.../cpu/transformers/logits_processor.h | 21 +++
.../core/graph/contrib_ops/contrib_defs.cc | 1 +
.../models/whisper/convert_to_onnx.py | 17 ++-
.../models/whisper/whisper_chain.py | 9 ++
.../test_whisper_timestamp_processor.py | 75 +++++++++++
10 files changed, 254 insertions(+), 7 deletions(-)
create mode 100644 onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5f856bbe3ca9d..07e892cabdb32 100755
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -428,7 +428,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 11)
+#### Inputs (5 - 12)
- input_ids : F
@@ -453,6 +453,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- Custom attention mask. Shape is (batch_size, sequence_length)
- decoder_input_ids (optional) : I
- The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)
+- logits_processor (optional) : I
+- Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
#### Outputs (1 - 3)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 7b04f3d6794e2..73b12eeb8ec74 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -422,7 +422,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
 |Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
 |AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br> **T1** = tensor(int32)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
 |BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
 |BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
 |CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -790,7 +790,7 @@ Do not modify directly.*
| |
|**Operator Domain:** *com.microsoft*||||
 |Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T2** = tensor(bool)|
 |BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index 55296a01159fe..4e34be98beb2a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -101,6 +101,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
repetition_penalty = 1.0f;
}
ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);
+
+ auto* logits_processor_tensor = context->Input<Tensor>(11);
+ logits_processor = logits_processor_tensor ? static_cast<int>(*logits_processor_tensor->Data<int32_t>()) : 0;
+ ORT_ENFORCE(logits_processor >= 0,
+ "logits_processor shall be a non-negative integer, got ", logits_processor);
}
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 93774d0b6b330..51c3ce49d051f 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -126,6 +126,8 @@ struct IGenerationParameters {
static constexpr int kModelTypeT5 = 1;
static constexpr int kModelTypeWhisper = 2;
+ static constexpr int kLogitsProcessorTypeWhisper = 1;
+
// Parameters from node attributes
int model_type; // 0 for GPT-2; 1 for encoder-decoder like T5; 2 for float inputs like Whisper
int eos_token_id;
@@ -143,6 +145,7 @@ struct IGenerationParameters {
float repetition_penalty;
int batch_size; // deduce from first dimension of input_ids
int sequence_length; // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5
+ int logits_processor;
gsl::span<const int32_t> vocab_mask;
gsl::span<const int32_t> prefix_vocab_mask;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
index d0641fedf978e..9f77c32f0c7cc 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -238,6 +238,128 @@ void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
#endif
}
+template <typename T>
+TimestampLogitsProcessor<T>::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
+ : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}
+
+template <typename T>
+void TimestampLogitsProcessor<T>::Process(const ISequences* sequences,
+ NextTokenScores<T>& next_token_scores) {
+ const int beg_token_id_ = eos_token_id_ + 107;
+ const int not_token_id_ = eos_token_id_ + 106;
+ const int solm_token_id_ = eos_token_id_ + 105;
+ const int sot_token_id_ = eos_token_id_ + 1;
+ constexpr int translate_token_id_ = 50358;
+ constexpr int transcribe_token_id_ = 50359;
+
+ const int batch_beam_size = next_token_scores.batch_beam_size;
+ const int vocab_size = next_token_scores.vocab_size;
+ for (int i = 0; i < batch_beam_size; i++) {
+ gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
+ gsl::span<const int32_t> sequence = sequences->GetSequence(i);
+ const size_t seq_length = sequence.size();
+
+ // Find first timestamp
+ size_t sample_begin = 0;
+ for (size_t j = 0; j < seq_length; j++) {
+ sample_begin++;
+ if (sequence[j] >= beg_token_id_) {
+ break;
+ }
+ }
+
+ // Suppress tokens
+ for (int j = 0; j < vocab_size; j++) {
+ // Suppress notimestamps and solm tokens
+ if (j == not_token_id_ || j == solm_token_id_) {
+ beam_token_scores[j] = std::numeric_limits<T>::lowest();
+ }
+
+ // Suppress sot, translate and transcribe tokens
+ if (seq_length > sample_begin) {
+ if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
+ beam_token_scores[j] = std::numeric_limits<T>::lowest();
+ }
+ }
+ }
+
+ // Timestamps should be in pair except the first one
+ const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
+ const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;
+ if (last_was_timestamp) {
+ if (penultimate_was_timestamp) {
+ // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated
+ for (int j = beg_token_id_; j < vocab_size; j++) {
+ beam_token_scores[j] = std::numeric_limits<T>::lowest();
+ }
+ } else {
+ // If timestamp doesn't show up in pair, generate timestamp
+ for (int j = 0; j < eos_token_id_; j++) {
+ beam_token_scores[j] = std::numeric_limits<T>::lowest();
+ }
+ }
+ }
+
+ // Find timestamp tokens
+ std::vector<int32_t> timestamps;
+ for (const auto& word_id : sequence) {
+ if (word_id >= beg_token_id_) {
+ timestamps.push_back(word_id);
+ }
+ }
+
+ // Timestamps will not decrease
+ const size_t timestamps_len = timestamps.size();
+ if (timestamps_len > 0) {
+ int timestamp_last = 0;
+ if (last_was_timestamp && !penultimate_was_timestamp) {
+ // For single timestamp at the end, next timestamp must not be smaller
+ timestamp_last = timestamps.back();
+ } else {
+ // For paired timestamp at the end, next timestamp must be greater
+ timestamp_last = timestamps.back() + 1;
+ }
+
+ for (int j = beg_token_id_; j < timestamp_last; j++) {
+ beam_token_scores[j] = std::numeric_limits<T>::lowest();
+ }
+ }
+
+ if (seq_length == sample_begin) {
+ const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
+ for (int j = last_allowed + 1; j < vocab_size; j++) {
+ beam_token_scores[j] = std::numeric_limits<T>::lowest();
+ }
+ }
+
+ // Calculate logsumexp on timestamps
+ float timestamp_logprob = std::numeric_limits<T>::lowest();
+ {
+ float logsumexp = 0.0f;
+ const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
+ for (int j = beg_token_id_; j < vocab_size; ++j) {
+ if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
+ logsumexp += expf(beam_token_scores[j] - logprob_max);
+ }
+ }
+ if (logsumexp > 0.0f) {
+ timestamp_logprob = logf(logsumexp) + logprob_max;
+ }
+ }
+
+ const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
+ if (timestamp_logprob > max_text_token_logprob) {
+ for (int j = 0; j < beg_token_id_; ++j) {
+ beam_token_scores[j] = std::numeric_limits<T>::lowest();
+ }
+ }
+ }
+
+#ifdef DEBUG_GENERATION
+ DumpScores("TimestampLogitsProcessor", next_token_scores);
+#endif
+}
+
void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 49a2a5bb324a6..664c497a106d4 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -138,6 +138,19 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
float presence_penalty_;
};
+template <typename T>
+class TimestampLogitsProcessor : public ILogitsProcessor<T> {
+ public:
+ TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index);
+
+ void Process(const ISequences* sequences,
+ NextTokenScores<T>& next_token_scores) override;
+
+ private:
+ int eos_token_id_;
+ int max_initial_timestamp_index_;
+};
+
class LogitsProcessorList : public ILogitsProcessorList {
public:
LogitsProcessorList() = default;
@@ -193,6 +206,13 @@ class LogitsProcessorList : public ILogitsProcessorList {
processor_list_.push_back(presence_penalty_processor_.get());
}
+ // Add timestamp processor for whisper model
+ if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) {
+ constexpr int max_initial_timestamp_index = 50;
+ timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<T>>(parameters.eos_token_id, max_initial_timestamp_index);
+ processor_list_.push_back(timestamp_processor_.get());
+ }
+
batch_beam_size_ = parameters.BatchBeamSize();
vocab_size_ = parameters.vocab_size;
}
@@ -208,6 +228,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
std::unique_ptr<MinLengthLogitsProcessor<T>> min_length_processor_;
std::unique_ptr<TemperatureLogitsProcessor<T>> temperature_processor_;
std::unique_ptr<PresencePenaltyLogitsProcessor<T>> presence_penalty_processor_;
+ std::unique_ptr<TimestampLogitsProcessor<T>> timestamp_processor_;
};
} // namespace transformers
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 1f209c1a11dec..f4da434c5ea0e 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1096,6 +1096,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
.Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
.Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
.Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
+ .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
.Output(2, "scores",
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index b5a09771b2c64..9beae2cd38ee2 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -22,7 +22,7 @@
logger = logging.getLogger("")
-def parse_arguments():
+def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
pretrained_models = PRETRAINED_WHISPER_MODELS
@@ -98,6 +98,15 @@ def parse_arguments():
)
parser.set_defaults(use_forced_decoder_ids=False)
+ parser.add_argument(
+ "-l",
+ "--use_logits_processor",
+ required=False,
+ action="store_true",
+ help="Use logits_processor as an extra graph input to enable specific logits processing",
+ )
+ parser.set_defaults(use_logits_processor=False)
+
parser.add_argument(
"-w",
"--overwrite",
@@ -176,7 +185,7 @@ def parse_arguments():
help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
)
- args = parser.parse_args()
+ args = parser.parse_args(argv)
return args
@@ -298,8 +307,8 @@ def export_onnx_models(
return output_paths
-def main():
- args = parse_arguments()
+def main(argv=None):
+ args = parse_arguments(argv)
setup_logger(args.verbose)
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index 5a0af8bd6058f..de7134eae505e 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -33,6 +33,11 @@ def chain_model(args):
]
if args.use_forced_decoder_ids:
beam_inputs.append("decoder_input_ids")
+ else:
+ beam_inputs.append("")
+
+ if args.use_logits_processor:
+ beam_inputs.append("logits_processor")
beam_outputs = ["sequences"]
node = helper.make_node("BeamSearch", inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode")
@@ -74,6 +79,10 @@ def chain_model(args):
)
graph_inputs.append(decoder_input_ids)
+ if args.use_logits_processor:
+ logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
+ graph_inputs.append(logits_processor)
+
# graph outputs
sequences = helper.make_tensor_value_info(
"sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
new file mode 100644
index 0000000000000..052c5ca264af9
--- /dev/null
+++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
@@ -0,0 +1,75 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import unittest
+
+import numpy as np
+import pytest
+import torch
+
+from onnxruntime import InferenceSession, SessionOptions
+
+
+class TestTimestampProcessor(unittest.TestCase):
+ def generate_model(self, arguments: str):
+ from onnxruntime.transformers.models.whisper.convert_to_onnx import main as whisper_to_onnx
+
+ whisper_to_onnx(arguments.split())
+
+ def generate_dataset(self):
+ from datasets import load_dataset
+ from transformers import AutoProcessor
+
+ processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+ inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
+ input_features = inputs.input_features
+ return [input_features, processor]
+
+ def run_timestamp(self, provider: str):
+ self.generate_model("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e")
+ [input_features, processor] = self.generate_dataset()
+ model_path = "./onnx_models/openai/whisper-tiny_beamsearch.onnx"
+ sess_options = SessionOptions()
+ sess_options.log_severity_level = 4
+ sess = InferenceSession(model_path, sess_options, providers=[provider])
+ input_data = input_features.repeat(1, 1, 1)
+ ort_inputs = {
+ "input_features": np.float32(input_data.cpu().numpy()),
+ "max_length": np.array([128], dtype=np.int32),
+ "min_length": np.array([0], dtype=np.int32),
+ "num_beams": np.array([1], dtype=np.int32),
+ "num_return_sequences": np.array([1], dtype=np.int32),
+ "length_penalty": np.array([1.0], dtype=np.float32),
+ "repetition_penalty": np.array([1.0], dtype=np.float32),
+ "logits_processor": np.array([1], dtype=np.int32),
+ }
+ ort_out = sess.run(None, ort_inputs)
+ ort_out_tensor = torch.from_numpy(ort_out[0])
+ ort_transcription = processor.batch_decode(
+ ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True
+ )
+ expected_transcription = [
+ {
+ "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+ "offsets": [
+ {
+ "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
+ "timestamp": (0.0, 5.44),
+ }
+ ],
+ }
+ ]
+ self.assertEqual(ort_transcription, expected_transcription)
+
+ @pytest.mark.slow
+ def test_timestamp_cpu(self):
+ provider = "CPUExecutionProvider"
+ self.run_timestamp(provider)
+
+
+if __name__ == "__main__":
+ unittest.main()