Add timestamp logits processor for whisper #15853

Merged · 23 commits · May 17, 2023
4 changes: 3 additions & 1 deletion docs/ContribOperators.md
@@ -428,7 +428,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>

#### Inputs (5 - 11)
#### Inputs (5 - 12)

<dl>
<dt><tt>input_ids</tt> : F</dt>
@@ -453,6 +453,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>decoder_input_ids</tt> (optional) : I</dt>
<dd>The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)</dd>
<dt><tt>logits_processor</tt> (optional) : I</dt>
<dd>Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)</dd>
</dl>
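At session run time the new input is fed like any other optional BeamSearch input. A minimal sketch, assuming a Whisper model already exported with this input; the model path and dummy features are illustrative, and the unit test at the end of this PR exercises the real flow:

```python
import numpy as np
from onnxruntime import InferenceSession

# Illustrative only: "whisper_beamsearch.onnx" stands in for a model whose
# BeamSearch node was exported with the optional logits_processor input.
sess = InferenceSession("whisper_beamsearch.onnx", providers=["CPUExecutionProvider"])

ort_inputs = {
    "input_features": np.zeros((1, 80, 3000), dtype=np.float32),  # dummy log-mel features
    "max_length": np.array([128], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([1], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "logits_processor": np.array([1], dtype=np.int32),  # 1 selects the Whisper timestamp processor
}
sequences = sess.run(None, ort_inputs)[0]
```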

#### Outputs (1 - 3)
4 changes: 2 additions & 2 deletions docs/OperatorKernels.md
@@ -422,7 +422,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
@@ -790,7 +790,7 @@ Do not modify directly.*
| |
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**<br> *in* bias:**T**<br> *in* skip:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**<br> *in* bias:**T**<br> *in* residual:**T**<br> *in* ratio:**T1**<br> *in* training_mode:**T2**<br> *out* output:**T**<br> *out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)<br/> **T2** = tensor(bool)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -101,6 +101,11 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
    repetition_penalty = 1.0f;
  }
  ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);

  auto* logits_processor_tensor = context->Input<Tensor>(11);
  logits_processor = logits_processor_tensor ? static_cast<int>(*logits_processor_tensor->Data<int32_t>()) : 0;
  ORT_ENFORCE(logits_processor >= 0,
              "logits_processor shall be a non-negative integer, got ", logits_processor);
}

void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -126,6 +126,8 @@ struct IGenerationParameters {
  static constexpr int kModelTypeT5 = 1;
  static constexpr int kModelTypeWhisper = 2;

  static constexpr int kLogitsProcessorTypeWhisper = 1;

  // Parameters from node attributes
  int model_type;  // 0 for GPT-2; 1 for encoder-decoder like T5; 2 for float inputs like Whisper
  int eos_token_id;
@@ -143,6 +145,7 @@
  float repetition_penalty;
  int batch_size;       // deduce from first dimension of input_ids
  int sequence_length;  // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5
  int logits_processor;

  gsl::span<const int32_t> vocab_mask;
  gsl::span<const int32_t> prefix_vocab_mask;
122 changes: 122 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -238,6 +238,128 @@ void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
#endif
}

template <typename T>
TimestampLogitsProcessor<T>::TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index)
    : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {}

template <typename T>
void TimestampLogitsProcessor<T>::Process(const ISequences* sequences,
                                          NextTokenScores<T>& next_token_scores) {
  const int beg_token_id_ = eos_token_id_ + 107;
  const int not_token_id_ = eos_token_id_ + 106;
  const int solm_token_id_ = eos_token_id_ + 105;
  const int sot_token_id_ = eos_token_id_ + 1;
  constexpr int translate_token_id_ = 50358;
Review comment from @tianleiwu (Contributor), May 13, 2023:
These constant configurations (the special token IDs) should be read from attributes.

Reply from @stevenlix (Author), May 14, 2023:
eos_token_id_ is from an attribute. The tokens listed here have a constant offset from eos_token_id_, so they may not need to be provided explicitly in the attributes.
  constexpr int transcribe_token_id_ = 50359;

  const int batch_beam_size = next_token_scores.batch_beam_size;
  const int vocab_size = next_token_scores.vocab_size;
  for (int i = 0; i < batch_beam_size; i++) {
    gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
    gsl::span<const int32_t> sequence = sequences->GetSequence(i);
    const size_t seq_length = sequence.size();

    // Find first timestamp
    size_t sample_begin = 0;
    for (size_t j = 0; j < seq_length; j++) {
      sample_begin++;
      if (sequence[j] >= beg_token_id_) {
        break;
      }
    }

    // Suppress tokens
    for (int j = 0; j < vocab_size; j++) {
      // Suppress notimestamps and solm tokens
      if (j == not_token_id_ || j == solm_token_id_) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }

      // Suppress sot, translate and transcribe tokens
      if (seq_length > sample_begin) {
        if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      }
    }

    // Timestamps should be in pair except the first one
    const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_;
    const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_;
    if (last_was_timestamp) {
      if (penultimate_was_timestamp) {
        // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated
        for (int j = beg_token_id_; j < vocab_size; j++) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      } else {
        // If timestamp doesn't show up in pair, generate timestamp
        for (int j = 0; j < eos_token_id_; j++) {
          beam_token_scores[j] = std::numeric_limits<T>::lowest();
        }
      }
    }

    // Find timestamp tokens
    std::vector<int32_t> timestamps;
    for (const auto& word_id : sequence) {
      if (word_id >= beg_token_id_) {
        timestamps.push_back(word_id);
      }
    }

    // Timestamps will not decrease
    const size_t timestamps_len = timestamps.size();
    if (timestamps_len > 0) {
      int timestamp_last = 0;
      if (last_was_timestamp && !penultimate_was_timestamp) {
        // For single timestamp at the end, next timestamp must not be smaller
        timestamp_last = timestamps.back();
      } else {
        // For paired timestamp at the end, next timestamp must be greater
        timestamp_last = timestamps.back() + 1;
      }

      for (int j = beg_token_id_; j < timestamp_last; j++) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }

    if (seq_length == sample_begin) {
      const int last_allowed = beg_token_id_ + max_initial_timestamp_index_;
      for (int j = last_allowed + 1; j < vocab_size; j++) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }

    // Calculate logsumexp on timestamps
    float timestamp_logprob = std::numeric_limits<T>::lowest();
    {
      float logsumexp = 0.0f;
      const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end());
      for (int j = beg_token_id_; j < vocab_size; ++j) {
        if (beam_token_scores[j] > std::numeric_limits<T>::lowest()) {
          logsumexp += expf(beam_token_scores[j] - logprob_max);
        }
      }
      if (logsumexp > 0.0f) {
        timestamp_logprob = logf(logsumexp) + logprob_max;
      }
    }

    const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_);
    if (timestamp_logprob > max_text_token_logprob) {
      for (int j = 0; j < beg_token_id_; ++j) {
        beam_token_scores[j] = std::numeric_limits<T>::lowest();
      }
    }
  }

#ifdef DEBUG_GENERATION
  DumpScores("TimestampLogitsProcessor", next_token_scores);
#endif
}
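The final block implements the timestamp dominance rule from openai/whisper: if the total probability mass assigned to timestamp tokens exceeds the probability of the best single text token, all text tokens are suppressed so the next token must be a timestamp. A hedged NumPy sketch of just this rule (function and variable names are illustrative, not from this PR):

```python
import numpy as np

def apply_timestamp_dominance(scores: np.ndarray, beg_token_id: int) -> np.ndarray:
    # Illustrative sketch: `scores` holds the scores for one beam after the
    # suppression passes above; tokens with id >= beg_token_id are timestamps.
    lowest = np.finfo(scores.dtype).min
    timestamp_scores = scores[beg_token_id:]
    active = timestamp_scores > lowest  # skip tokens already suppressed
    timestamp_logprob = lowest
    if active.any():
        m = timestamp_scores[active].max()
        # logsumexp over the timestamp tokens, computed stably around the max
        timestamp_logprob = np.log(np.exp(timestamp_scores[active] - m).sum()) + m
    max_text_token_logprob = scores[:beg_token_id].max()
    if timestamp_logprob > max_text_token_logprob:
        # Timestamp mass dominates: force the next token to be a timestamp
        scores[:beg_token_id] = lowest
    return scores
```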

void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
  LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
}
21 changes: 21 additions & 0 deletions onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -138,6 +138,19 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
  float presence_penalty_;
};

template <typename T>
class TimestampLogitsProcessor : public ILogitsProcessor<T> {
 public:
  TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index);

  void Process(const ISequences* sequences,
               NextTokenScores<T>& next_token_scores) override;

 private:
  int eos_token_id_;
  int max_initial_timestamp_index_;
};

class LogitsProcessorList : public ILogitsProcessorList {
 public:
  LogitsProcessorList() = default;
@@ -193,6 +206,13 @@ class LogitsProcessorList : public ILogitsProcessorList {
      processor_list_.push_back(presence_penalty_processor_.get());
    }

    // Add timestamp processor for whisper model
    if (parameters.model_type == IGenerationParameters::kModelTypeWhisper &&
        parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) {
      constexpr int max_initial_timestamp_index = 50;
      timestamp_processor_ = std::make_unique<TimestampLogitsProcessor<float>>(parameters.eos_token_id, max_initial_timestamp_index);
      processor_list_.push_back(timestamp_processor_.get());
    }
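Note: max_initial_timestamp_index is fixed at 50 here, which appears to correspond to openai/whisper's default cap of 1.0 second on the first timestamp (timestamp tokens are spaced 0.02 s apart).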

    batch_beam_size_ = parameters.BatchBeamSize();
    vocab_size_ = parameters.vocab_size;
  }
@@ -208,6 +228,7 @@ class LogitsProcessorList : public ILogitsProcessorList {
  std::unique_ptr<MinLengthLogitsProcessor<float>> min_length_processor_;
  std::unique_ptr<TemperatureLogitsProcessor<float>> temperature_processor_;
  std::unique_ptr<PresencePenaltyLogitsProcessor<float>> presence_penalty_processor_;
  std::unique_ptr<TimestampLogitsProcessor<float>> timestamp_processor_;
};

} // namespace transformers
1 change: 1 addition & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1096,6 +1096,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
        .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional)
        .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
        .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional)
        .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional)
        .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
        .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
        .Output(2, "scores",
onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -22,7 +22,7 @@
logger = logging.getLogger("")


def parse_arguments():
def parse_arguments(argv=None):
    parser = argparse.ArgumentParser()

    pretrained_models = PRETRAINED_WHISPER_MODELS
@@ -98,6 +98,15 @@ def parse_arguments():
    )
    parser.set_defaults(use_forced_decoder_ids=False)

    parser.add_argument(
        "-l",
        "--use_logits_processor",
        required=False,
        action="store_true",
        help="Use logits_processor as an extra graph input to enable specific logits processing",
    )
    parser.set_defaults(use_logits_processor=False)
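Because parse_arguments and main now accept argv, the export can be driven programmatically as well as from the command line; the new unit test uses exactly this entry point. A minimal sketch mirroring the test's invocation:

```python
# Sketch: export whisper-tiny with the timestamp logits processor input
# enabled (-l), using the same flags as the test added in this PR.
from onnxruntime.transformers.models.whisper.convert_to_onnx import main

main("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e".split())
```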

    parser.add_argument(
        "-w",
        "--overwrite",
@@ -176,7 +185,7 @@ def parse_arguments():
        help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
    )

    args = parser.parse_args()
    args = parser.parse_args(argv)

    return args

@@ -298,8 +307,8 @@ def export_onnx_models(
    return output_paths


def main():
    args = parse_arguments()
def main(argv=None):
    args = parse_arguments(argv)

    setup_logger(args.verbose)

onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -33,6 +33,11 @@ def chain_model(args):
    ]
    if args.use_forced_decoder_ids:
        beam_inputs.append("decoder_input_ids")
    else:
        beam_inputs.append("")

    if args.use_logits_processor:
        beam_inputs.append("logits_processor")
    beam_outputs = ["sequences"]

    node = helper.make_node("BeamSearch", inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode")
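Note: ONNX node inputs are positional, and an empty input name means "not provided". Appending "" when decoder_input_ids is absent keeps the optional slot occupied, so logits_processor still lands at its expected position (input index 11) in the BeamSearch node's input list.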
@@ -74,6 +79,10 @@ def chain_model(args):
    )
    graph_inputs.append(decoder_input_ids)

    if args.use_logits_processor:
        logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
        graph_inputs.append(logits_processor)

    # graph outputs
    sequences = helper.make_tensor_value_info(
        "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
@@ -0,0 +1,75 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import unittest

import numpy as np
import pytest
import torch

from onnxruntime import InferenceSession, SessionOptions


class TestTimestampProcessor(unittest.TestCase):
    def generate_model(self, arguments: str):
        from onnxruntime.transformers.models.whisper.convert_to_onnx import main as whisper_to_onnx

        whisper_to_onnx(arguments.split())

    def generate_dataset(self):
        from datasets import load_dataset
        from transformers import AutoProcessor

        processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
        ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        input_features = inputs.input_features
        return [input_features, processor]

    def run_timestamp(self, provider: str):
        self.generate_model("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e")
        [input_features, processor] = self.generate_dataset()
        model_path = "./onnx_models/openai/whisper-tiny_beamsearch.onnx"
        sess_options = SessionOptions()
        sess_options.log_severity_level = 4
        sess = InferenceSession(model_path, sess_options, providers=[provider])
        input_data = input_features.repeat(1, 1, 1)
        ort_inputs = {
            "input_features": np.float32(input_data.cpu().numpy()),
            "max_length": np.array([128], dtype=np.int32),
            "min_length": np.array([0], dtype=np.int32),
            "num_beams": np.array([1], dtype=np.int32),
            "num_return_sequences": np.array([1], dtype=np.int32),
            "length_penalty": np.array([1.0], dtype=np.float32),
            "repetition_penalty": np.array([1.0], dtype=np.float32),
            "logits_processor": np.array([1], dtype=np.int32),
        }
        ort_out = sess.run(None, ort_inputs)
        ort_out_tensor = torch.from_numpy(ort_out[0])
        ort_transcription = processor.batch_decode(
            ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True
        )
        expected_transcription = [
            {
                "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
                "offsets": [
                    {
                        "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.",
                        "timestamp": (0.0, 5.44),
                    }
                ],
            }
        ]
        self.assertEqual(ort_transcription, expected_transcription)

    @pytest.mark.slow
    def test_timestamp_cpu(self):
        provider = "CPUExecutionProvider"
        self.run_timestamp(provider)


if __name__ == "__main__":
unittest.main()