-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add timestamp logits processor for whisper #15853
Merged
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
01ad1da
add timestamp processor for whisper
stevenlix 9a3e611
Update logits_processor.cc
stevenlix 90eff04
Update logits_processor.h
stevenlix 01fb75b
Update logits_processor.cc
stevenlix 29f331c
add new input in beamsearch contrib op to enable timestamp processing
stevenlix 9ef5d4a
resolve conflict
stevenlix 86339ab
update document
stevenlix 5be695b
fix format
stevenlix cd192db
add test
stevenlix 58a9538
Update test_whisper_timestamp_processor.py
stevenlix bb5113a
fix issues
stevenlix 22610f2
Merge branch 'timestamp' of https://github.com/microsoft/onnxruntime …
stevenlix 72e90ca
format code
stevenlix c207b16
Merge branch 'main' into timestamp
stevenlix bd6f079
rename timestamp_enable to logits_processor
stevenlix 7d4e2d7
resolve conflicts
stevenlix 7048cdd
fix issue
stevenlix 835ba50
update docs
stevenlix dffcfeb
make logits_processor input as optional
stevenlix 8ff2600
formatting
stevenlix 33ec665
formatting
stevenlix c2f1886
Merge branch 'main' into timestamp
stevenlix 2ac6d9e
update docs
stevenlix File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# ------------------------------------------------------------------------- | ||
|
||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
|
||
from onnxruntime import InferenceSession, SessionOptions | ||
|
||
|
||
class TestTimestampProcessor(unittest.TestCase): | ||
def generate_model(self, arguments: str): | ||
from onnxruntime.transformers.models.whisper.convert_to_onnx import main as whisper_to_onnx | ||
|
||
whisper_to_onnx(arguments.split()) | ||
|
||
def generate_dataset(self): | ||
from datasets import load_dataset | ||
from transformers import AutoProcessor | ||
|
||
processor = AutoProcessor.from_pretrained("openai/whisper-tiny") | ||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") | ||
inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") | ||
input_features = inputs.input_features | ||
return [input_features, processor] | ||
|
||
def run_timestamp(self, provider: str): | ||
self.generate_model("-m openai/whisper-tiny --optimize_onnx --precision fp32 -l -e") | ||
[input_features, processor] = self.generate_dataset() | ||
model_path = "./onnx_models/openai/whisper-tiny_beamsearch.onnx" | ||
sess_options = SessionOptions() | ||
sess_options.log_severity_level = 4 | ||
sess = InferenceSession(model_path, sess_options, providers=[provider]) | ||
input_data = input_features.repeat(1, 1, 1) | ||
ort_inputs = { | ||
"input_features": np.float32(input_data.cpu().numpy()), | ||
"max_length": np.array([128], dtype=np.int32), | ||
"min_length": np.array([0], dtype=np.int32), | ||
"num_beams": np.array([1], dtype=np.int32), | ||
"num_return_sequences": np.array([1], dtype=np.int32), | ||
"length_penalty": np.array([1.0], dtype=np.float32), | ||
"repetition_penalty": np.array([1.0], dtype=np.float32), | ||
"logits_processor": np.array([1], dtype=np.int32), | ||
} | ||
ort_out = sess.run(None, ort_inputs) | ||
ort_out_tensor = torch.from_numpy(ort_out[0]) | ||
ort_transcription = processor.batch_decode( | ||
ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True | ||
) | ||
expected_transcription = [ | ||
{ | ||
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", | ||
"offsets": [ | ||
{ | ||
"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", | ||
"timestamp": (0.0, 5.44), | ||
} | ||
], | ||
} | ||
] | ||
self.assertEqual(ort_transcription, expected_transcription) | ||
|
||
@pytest.mark.slow | ||
def test_timestamp_cpu(self): | ||
provider = "CPUExecutionProvider" | ||
self.run_timestamp(provider) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These constant configuration (these special IDs) shall read from attributes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eos_token_id_ is from attribute. Tokens listed here have constant offset to eos_token_id_ and may not need to be provided explicitly in the attributes