From 270c09a37f36e13ee986144b2385199d23e030a0 Mon Sep 17 00:00:00 2001
From: stevenlix <38092805+stevenlix@users.noreply.github.com>
Date: Tue, 16 May 2023 21:40:00 -0700
Subject: [PATCH] Add timestamp logits processor for whisper (#15853)

Enable timestamp estimation and logits processing for the Whisper model.
---
 docs/ContribOperators.md                      |   4 +-
 docs/OperatorKernels.md                       |   4 +-
 .../transformers/beam_search_parameters.cc    |   5 +
 .../cpu/transformers/generation_shared.h      |   3 +
 .../cpu/transformers/logits_processor.cc      | 122 ++++++++++++++++++
 .../cpu/transformers/logits_processor.h       |  21 +++
 .../core/graph/contrib_ops/contrib_defs.cc    |   1 +
 .../models/whisper/convert_to_onnx.py         |  17 ++-
 .../models/whisper/whisper_chain.py           |   9 ++
 .../test_whisper_timestamp_processor.py       |  75 +++++++++++
 10 files changed, 254 insertions(+), 7 deletions(-)
 create mode 100644 onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py

diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5f856bbe3ca9d..07e892cabdb32 100755
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -428,7 +428,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
 </dl>
 
-#### Inputs (5 - 11)
+#### Inputs (5 - 12)
 
 <dl>
 <dt><tt>input_ids</tt> : F</dt>
@@ -453,6 +453,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
 <dt><tt>decoder_input_ids</tt> (optional) : I</dt>
 <dd>The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)</dd>
+<dt><tt>logits_processor</tt> (optional) : I</dt>
+<dd>Specific logits processor for different types of beam search models. Default value 0 means no specific logits processor is used. Accepts values >= 0. Shape is (1)</dd>
 
 #### Outputs (1 - 3)
 
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 7b04f3d6794e2..73b12eeb8ec74 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -422,7 +422,7 @@ Do not modify directly.*
 |**Operator Domain:** *com.microsoft*||||
 |Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
 |AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br> **T1** = tensor(int32)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
 |BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
 |BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
 |CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -790,7 +790,7 @@ Do not modify directly.*
 | |
 |**Operator Domain:** *com.microsoft*||||
 |Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
-|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
+|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br> **T2** = tensor(bool)|
 |BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
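
The tables above imply a simple runtime contract: `logits_processor` is the optional twelfth input (index 11) of BeamSearch, an int32 tensor of shape (1), where 0 (or omitting the input) keeps the default behavior and 1 selects the Whisper timestamp processor. A minimal sketch of calling the exported model, assuming a chained Whisper beam-search model at the placeholder path `whisper_beamsearch.onnx` (the path and the zero-filled feature tensor are illustrative, not part of this patch):

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("whisper_beamsearch.onnx", providers=["CPUExecutionProvider"])

# whisper-tiny consumes 80 mel bins x 3000 frames; zeros stand in for real audio features.
audio_features = np.zeros((1, 80, 3000), dtype=np.float32)

outputs = sess.run(None, {
    "input_features": audio_features,
    "max_length": np.array([128], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([1], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "logits_processor": np.array([1], dtype=np.int32),  # 1 = Whisper timestamp processing
})
# outputs[0] holds the generated sequences, now interleaving text and timestamp tokens.
```
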
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index 55296a01159fe..4e34be98beb2a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -101,6 +101,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
     repetition_penalty = 1.0f;
   }
   ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);
+
+  auto* logits_processor_tensor = context->Input<Tensor>(11);
+  logits_processor = logits_processor_tensor ? static_cast<int>(*logits_processor_tensor->Data<int32_t>()) : 0;
+  ORT_ENFORCE(logits_processor >= 0,
+              "logits_processor shall be a non-negative integer, got ", logits_processor);
 }
 
 void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 93774d0b6b330..51c3ce49d051f 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -126,6 +126,8 @@ struct IGenerationParameters {
   static constexpr int kModelTypeT5 = 1;
   static constexpr int kModelTypeWhisper = 2;
 
+  static constexpr int kLogitsProcessorTypeWhisper = 1;
+
   // Parameters from node attributes
   int model_type;  // 0 for GPT-2; 1 for encoder-decoder like T5; 2 for float inputs like Whisper
   int eos_token_id;
@@ -143,6 +145,7 @@ struct IGenerationParameters {
   float repetition_penalty;
   int batch_size;       // deduce from first dimension of input_ids
   int sequence_length;  // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5
+  int logits_processor;
 
   gsl::span<const int32_t> vocab_mask;
   gsl::span<const int32_t> prefix_vocab_mask;
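
The processor added next derives most of Whisper's special-token ids as offsets from `eos_token_id` rather than hard-coding them. Evaluated for the multilingual Whisper vocabulary, where `<|endoftext|>` is 50257 (an assumption to verify against the tokenizer in use; English-only checkpoints use a different id), the offsets work out as follows:

```python
# Offsets used by TimestampLogitsProcessor, evaluated for the multilingual
# Whisper vocabulary. eos_token_id comes from the model configuration.
eos_token_id = 50257                 # <|endoftext|>
sot_token_id = eos_token_id + 1      # 50258 = <|startoftranscript|>
solm_token_id = eos_token_id + 105   # 50362 (named "solm" in the C++ below)
not_token_id = eos_token_id + 106    # 50363 = <|notimestamps|>
beg_token_id = eos_token_id + 107    # 50364 = <|0.00|>, the first timestamp token
# translate (50358) and transcribe (50359) are hard-coded constants in the C++ below.
print(sot_token_id, solm_token_id, not_token_id, beg_token_id)
```
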
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
index d0641fedf978e..9f77c32f0c7cc 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -238,6 +238,128 @@ void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
 #endif
 }
 
+template <typename T>
+TimestampLogitsProcessor<T>::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
+    : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}
+
+template <typename T>
+void TimestampLogitsProcessor<T>::Process(const ISequences* sequences,
+                                          NextTokenScores<T>& next_token_scores) {
+  const int beg_token_id_ = eos_token_id_ + 107;
+  const int not_token_id_ = eos_token_id_ + 106;
+  const int solm_token_id_ = eos_token_id_ + 105;
+  const int sot_token_id_ = eos_token_id_ + 1;
+  constexpr int translate_token_id_ = 50358;
+  constexpr int transcribe_token_id_ = 50359;
+
+  const int batch_beam_size = next_token_scores.batch_beam_size;
+  const int vocab_size = next_token_scores.vocab_size;
+  for (int i = 0; i < batch_beam_size; i++) {
+    gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
+    gsl::span<const int32_t> sequence = sequences->GetSequence(i);
+    const size_t seq_length = sequence.size();
+
+    // Find the position right after the first timestamp
+    size_t sample_begin = 0;
+    for (size_t j = 0; j < seq_length; j++) {
+      sample_begin++;
+      if (sequence[j] >= beg_token_id_) {
+        break;
+      }
+    }
+
+    // Suppress tokens
+    for (int j = 0; j < vocab_size; j++) {
+      // Suppress notimestamps and solm tokens
+      if (j == not_token_id_ || j == solm_token_id_) {
+        beam_token_scores[j] = std::numeric_limits<T>::lowest();
+      }
+
+      // Suppress sot, translate and transcribe tokens
+      if (seq_length > sample_begin) {
+        if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
+          beam_token_scores[j] = std::numeric_limits<T>::lowest();
+        }
+      }
+    }
+
+    // Timestamps should appear in pairs, except for the first one
+    const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
+    const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;
+    if (last_was_timestamp) {
+      if (penultimate_was_timestamp) {
+        // If timestamps show up as a pair, or it is the first timestamp, no more timestamps are generated
+        for (int j = beg_token_id_; j < vocab_size; j++) {
+          beam_token_scores[j] = std::numeric_limits<T>::lowest();
+        }
+      } else {
+        // If the last timestamp is unpaired, force the next token to be a timestamp
+        for (int j = 0; j < eos_token_id_; j++) {
+          beam_token_scores[j] = std::numeric_limits<T>::lowest();
+        }
+      }
+    }
+
+    // Find timestamp tokens
+    std::vector<int32_t> timestamps;
+    for (const auto& word_id : sequence) {
+      if (word_id >= beg_token_id_) {
+        timestamps.push_back(word_id);
+      }
+    }
+
+    // Timestamps will not decrease
+    const size_t timestamps_len = timestamps.size();
+    if (timestamps_len > 0) {
+      int timestamp_last = 0;
+      if (last_was_timestamp && !penultimate_was_timestamp) {
+        // For a single timestamp at the end, the next timestamp must not be smaller
+        timestamp_last = timestamps.back();
+      } else {
+        // For a paired timestamp at the end, the next timestamp must be greater
+        timestamp_last = timestamps.back() + 1;
+      }
+
+      for (int j = beg_token_id_; j < timestamp_last; j++) {
+        beam_token_scores[j] = std::numeric_limits<T>::lowest();
+      }
+    }
+
+    if (seq_length == sample_begin) {
+      const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
+      for (int j = last_allowed + 1; j < vocab_size; j++) {
+        beam_token_scores[j] = std::numeric_limits<T>::lowest();
+      }
+    }
+
+    // Calculate logsumexp on timestamps
+    float timestamp_logprob = std::numeric_limits<T>::lowest();
+    {
+      float logsumexp = 0.0f;
+      const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
+      for (int j = beg_token_id_; j < vocab_size; ++j) {
+        if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
+          logsumexp += expf(beam_token_scores[j] - logprob_max);
+        }
+      }
+      if (logsumexp > 0.0f) {
+        timestamp_logprob = logf(logsumexp) + logprob_max;
+      }
+    }
+
+    const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
+    if (timestamp_logprob > max_text_token_logprob) {
+      for (int j = 0; j < beg_token_id_; ++j) {
+        beam_token_scores[j] = std::numeric_limits<T>::lowest();
+      }
+    }
+  }
+
+#ifdef DEBUG_GENERATION
+  DumpScores("TimestampLogitsProcessor", next_token_scores);
+#endif
+}
+
 void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
   LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
 }
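
The two central rules in the kernel above are the pairing constraint and the logsumexp comparison. As a simplified single-beam restatement (a sketch, not the implementation: it omits the prompt-prefix `sample_begin` bookkeeping, the special-token suppression, the monotonicity rule, and the initial-timestamp cap):

```python
import numpy as np

def timestamp_rules_sketch(scores, sequence, eos_token_id):
    """scores: float vector over the vocabulary; sequence: tokens so far."""
    beg = eos_token_id + 107  # first timestamp token
    neg_inf = np.finfo(scores.dtype).min

    # Rule 1: timestamps come in pairs. A lone trailing timestamp forces the
    # next token to be a timestamp; a closed pair forbids emitting one now.
    last_was_ts = len(sequence) >= 1 and sequence[-1] >= beg
    penult_was_ts = len(sequence) < 2 or sequence[-2] >= beg
    if last_was_ts:
        if penult_was_ts:
            scores[beg:] = neg_inf            # pair closed: no timestamp now
        else:
            scores[:eos_token_id] = neg_inf   # open pair: force a timestamp

    # Rule 2: if the total (log-sum-exp) score mass on timestamp tokens exceeds
    # the best single text-token score, commit to emitting a timestamp.
    ts_logprob = np.logaddexp.reduce(scores[beg:].astype(np.float64))
    if ts_logprob > scores[:beg].max():
        scores[:beg] = neg_inf
    return scores

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    scores = rng.normal(size=51865).astype(np.float32)  # multilingual vocab size
    # <|startoftranscript|> followed by an opening timestamp token:
    print(timestamp_rules_sketch(scores, [50258, 50364], eos_token_id=50257)[:5])
```

Like the C++ code, the sketch compares unnormalized scores on both sides of rule 2, so no softmax normalization is needed for the comparison to be valid.
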
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 49a2a5bb324a6..664c497a106d4 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -138,6 +138,19 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
   float presence_penalty_;
 };
 
+template <typename T>
+class TimestampLogitsProcessor : public ILogitsProcessor<T> {
+ public:
+  TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index);
+
+  void Process(const ISequences* sequences,
+               NextTokenScores<T>& next_token_scores) override;
+
+ private:
+  int eos_token_id_;
+  int max_initial_timestamp_index_;
+};
+
 class LogitsProcessorList : public ILogitsProcessorList {
  public:
   LogitsProcessorList() = default;
@@ -193,6 +206,13 @@ class LogitsProcessorList : public ILogitsProcessorList {
       processor_list_.push_back(presence_penalty_processor_.get());
     }
 
+    // Add the timestamp processor for the Whisper model
+    if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) {
+      constexpr int max_initial_timestamp_index = 50;
+      timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<float>>(parameters.eos_token_id, max_initial_timestamp_index);
+      processor_list_.push_back(timestamp_processor_.get());
+    }
+
     batch_beam_size_ = parameters.BatchBeamSize();
     vocab_size_ = parameters.vocab_size;
   }
@@ -208,6 +228,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
   std::unique_ptr<MinLengthLogitsProcessor<float>> min_length_processor_;
   std::unique_ptr<TemperatureLogitsProcessor<float>> temperature_processor_;
   std::unique_ptr<PresencePenaltyLogitsProcessor<float>> presence_penalty_processor_;
+  std::unique_ptr<TimestampLogitsProcessor<float>> timestamp_processor_;
 };
 
 }  // namespace transformers
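
With the processor wired into `LogitsProcessorList`, generated sequences interleave text and timestamp tokens; mapping a timestamp token back to seconds is simple arithmetic. A hypothetical helper (the 0.02 s granularity is standard Whisper behavior, not something defined in this patch):

```python
# Hypothetical helper (not part of the patch): map a Whisper timestamp token
# to seconds. Timestamp tokens advance in 0.02 s steps starting from the
# first timestamp token (eos_token_id + 107, i.e. <|0.00|>).
def timestamp_token_to_seconds(token_id, eos_token_id=50257, time_precision=0.02):
    beg_token_id = eos_token_id + 107
    if token_id < beg_token_id:
        raise ValueError("not a timestamp token")
    return (token_id - beg_token_id) * time_precision

# The test below expects the segment (0.0, 5.44): 5.44 s / 0.02 s = 272,
# i.e. timestamp token 50364 + 272 = 50636.
```
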
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 1f209c1a11dec..f4da434c5ea0e 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1096,6 +1096,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
                                 .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
                                 .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
                                 .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
+                                .Input(11, "logits_processor", "Specific logits processor for different types of beam search models. Default value 0 means no specific logits processor is used. Accepts values >= 0. Shape is (1)", "I", OpSchema::Optional)
                                 .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
                                 .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
                                 .Output(2, "scores",
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index b5a09771b2c64..9beae2cd38ee2 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -22,7 +22,7 @@
 logger = logging.getLogger("")
 
 
-def parse_arguments():
+def parse_arguments(argv=None):
     parser = argparse.ArgumentParser()
 
     pretrained_models = PRETRAINED_WHISPER_MODELS
@@ -98,6 +98,15 @@ def parse_arguments():
     )
     parser.set_defaults(use_forced_decoder_ids=False)
 
+    parser.add_argument(
+        "-l",
+        "--use_logits_processor",
+        required=False,
+        action="store_true",
+        help="Use logits_processor as an extra graph input to enable specific logits processing",
+    )
+    parser.set_defaults(use_logits_processor=False)
+
     parser.add_argument(
         "-w",
         "--overwrite",
@@ -176,7 +185,7 @@ def parse_arguments():
         help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
     )
 
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
 
     return args
 
@@ -298,8 +307,8 @@ def export_onnx_models(
     return output_paths
 
 
-def main():
-    args = parse_arguments()
+def main(argv=None):
+    args = parse_arguments(argv)
 
     setup_logger(args.verbose)
 
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index 5a0af8bd6058f..de7134eae505e 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -33,6 +33,11 @@ def chain_model(args):
     ]
     if args.use_forced_decoder_ids:
         beam_inputs.append("decoder_input_ids")
+    else:
+        beam_inputs.append("")
+
+    if args.use_logits_processor:
+        beam_inputs.append("logits_processor")
 
     beam_outputs = ["sequences"]
     node = helper.make_node("BeamSearch", inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode")
@@ -74,6 +79,10 @@ def chain_model(args):
         )
         graph_inputs.append(decoder_input_ids)
 
+    if args.use_logits_processor:
+        logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
+        graph_inputs.append(logits_processor)
+
     # graph outputs
     sequences = helper.make_tensor_value_info(
         "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
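
The `argv` plumbing added to `parse_arguments`/`main` above exists so the exporter can be driven in-process instead of only from the command line. A sketch mirroring how the new test invokes it (`-l` is the new `--use_logits_processor` switch; the other flags match the test below):

```python
from onnxruntime.transformers.models.whisper.convert_to_onnx import main as whisper_to_onnx

# Export whisper-tiny with the logits_processor graph input wired in (-l).
whisper_to_onnx("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e".split())
# The chained beam-search model lands at
# ./onnx_models/openai/whisper-tiny_beamsearch.onnx (the path the test reads).
```
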
+# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +import pytest +import torch + +from onnxruntime import InferenceSession, SessionOptions + + +class TestTimestampProcessor(unittest.TestCase): + def generate_model(self, arguments: str): + from onnxruntime.transformers.models.whisper.convert_to_onnx import main as whisper_to_onnx + + whisper_to_onnx(arguments.split()) + + def generate_dataset(self): + from datasets import load_dataset + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained("openai/whisper-tiny") + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + input_features = inputs.input_features + return [input_features, processor] + + def run_timestamp(self, provider: str): + self.generate_model("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e") + [input_features, processor] = self.generate_dataset() + model_path = "./onnx_models/openai/whisper-tiny_beamsearch.onnx" + sess_options = SessionOptions() + sess_options.log_severity_level = 4 + sess = InferenceSession(model_path, sess_options, providers=[provider]) + input_data = input_features.repeat(1, 1, 1) + ort_inputs = { + "input_features": np.float32(input_data.cpu().numpy()), + "max_length": np.array([128], dtype=np.int32), + "min_length": np.array([0], dtype=np.int32), + "num_beams": np.array([1], dtype=np.int32), + "num_return_sequences": np.array([1], dtype=np.int32), + "length_penalty": np.array([1.0], dtype=np.float32), + "repetition_penalty": np.array([1.0], dtype=np.float32), + "logits_processor": np.array([1], dtype=np.int32), + } + ort_out = sess.run(None, ort_inputs) + ort_out_tensor = torch.from_numpy(ort_out[0]) + ort_transcription = processor.batch_decode( + ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True + ) + expected_transcription = [ + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "offsets": [ + { + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", + "timestamp": (0.0, 5.44), + } + ], + } + ] + self.assertEqual(ort_transcription, expected_transcription) + + @pytest.mark.slow + def test_timestamp_cpu(self): + provider = "CPUExecutionProvider" + self.run_timestamp(provider) + + +if __name__ == "__main__": + unittest.main()