
Convert whisper models to onnx format #238

Merged (22 commits) on Aug 7, 2023

Conversation

@csukuangfj (Collaborator) commented Aug 5, 2023

We are trying to support whisper models in sherpa-onnx for non-streaming speech recognition.

As a first step, we have converted the model to ONNX format and successfully tested the exported ONNX model in Python using greedy search.

TODOs

  • Modify kaldi-native-fbank to support features used by whisper
  • Add C++ implementation for whisper (implement greedy search first)
  • Video demos
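
As a rough illustration of working with the exported model (a sketch only; the file name below is an assumption, and the export is assumed to produce a separate encoder model), the encoder's input/output signature can be inspected with onnxruntime before wiring up greedy search:

import onnxruntime as ort

# Load the exported encoder (file name is an assumption for illustration).
session = ort.InferenceSession("tiny.en-encoder.onnx",
                               providers=["CPUExecutionProvider"])

# Print the input/output signature to verify the export before decoding.
for inp in session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)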

@csukuangfj merged commit 45b9d4a into k2-fsa:master on Aug 7, 2023
@csukuangfj deleted the convert-whisper branch on August 7, 2023 at 04:34
@csukuangfj (Collaborator, Author):

Please visit
https://huggingface.co/spaces/k2-fsa/automatic-speech-recognition
to try whisper models in your browser.

[Screenshot: 2023-08-07 14:18:08]

@jackwenshann:

Hi, how can I convert a pytorch_model.bin downloaded from the official site into ONNX format?

@csukuangfj (Collaborator, Author):

> Hi, how can I convert a pytorch_model.bin downloaded from the official site into ONNX format?

https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/export-onnx.html
This is the documentation. It contains the detailed steps and links to the code.

@jackwenshann:

OK, I'll give it a try.

@jackwenshann:

> Hi, how can I convert a pytorch_model.bin downloaded from the official site into ONNX format?

> https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/export-onnx.html This is the documentation. It contains the detailed steps and links to the code.

Hi, this script automatically downloads the model from the official site and converts it. I have a model fine-tuned from tiny.en and would like to convert the trained .bin file directly into the corresponding ONNX format. Is there a good way to do this? Do I need to convert it to .pt format first? Looking forward to your answer, thanks.

@csukuangfj (Collaborator, Author):

> Hi, how can I convert a pytorch_model.bin downloaded from the official site into ONNX format?

> https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/export-onnx.html This is the documentation. It contains the detailed steps and links to the code.

> Hi, this script automatically downloads the model from the official site and converts it. I have a model fine-tuned from tiny.en and would like to convert the trained .bin file directly into the corresponding ONNX format. Is there a good way to do this? Do I need to convert it to .pt format first? Looking forward to your answer, thanks.

.bin and .pt just differ in file extension. The official .bin is also saved with torch.save().

As long as your model file has the same structure as the official one, you can export it to ONNX with our script.

Did you run into any problems when you ran it?

@jackwenshann:

> Hi, how can I convert a pytorch_model.bin downloaded from the official site into ONNX format?

> https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/export-onnx.html This is the documentation. It contains the detailed steps and links to the code.

> Hi, this script automatically downloads the model from the official site and converts it. I have a model fine-tuned from tiny.en and would like to convert the trained .bin file directly into the corresponding ONNX format. Is there a good way to do this? Do I need to convert it to .pt format first? Looking forward to your answer, thanks.

> .bin and .pt just differ in file extension. The official .bin is also saved with torch.save().

> As long as your model file has the same structure as the official one, you can export it to ONNX with our script.

> Did you run into any problems when you ran it?

Right now, running the script you provided automatically downloads the official model. If I force the script to point to our model instead, it reports the following error:
dims = ModelDimensions(**checkpoint["dims"])
KeyError: 'dims'

@csukuangfj (Collaborator, Author):

How did you save your checkpoint?

As I said above, it has to be fully compatible with the official format.

I suggest you look at what the state_dict inside the official .bin contains.
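
A minimal sketch of such an inspection (the checkpoint file name is an assumption for illustration):

import torch

checkpoint = torch.load("tiny.en.pt", map_location="cpu")
print(checkpoint.keys())                                 # expected: dict_keys(['dims', 'model_state_dict'])
print(checkpoint["dims"])                                # model dimensions read by the export script
print(list(checkpoint["model_state_dict"].keys())[:5])   # first few parameter names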

@ziggy1209:

> Hi, how can I convert a pytorch_model.bin downloaded from the official site into ONNX format?

> https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/export-onnx.html This is the documentation. It contains the detailed steps and links to the code.

> Hi, this script automatically downloads the model from the official site and converts it. I have a model fine-tuned from tiny.en and would like to convert the trained .bin file directly into the corresponding ONNX format. Is there a good way to do this? Do I need to convert it to .pt format first? Looking forward to your answer, thanks.

> .bin and .pt just differ in file extension. The official .bin is also saved with torch.save().
> As long as your model file has the same structure as the official one, you can export it to ONNX with our script.
> Did you run into any problems when you ran it?

> Right now, running the script you provided automatically downloads the official model. If I force the script to point to our model instead, it reports: dims = ModelDimensions(**checkpoint["dims"]) KeyError: 'dims'

Hi there!
I just ran into the same problem while exporting a model I distilled myself to ONNX. There seems to be a mismatch between a model compatible with whisper.load_model() and the model saved by the distillation script.
The keys of a model compatible with whisper.load_model() are ['dims', 'model_state_dict'], whereas the distilled model consists of its state dict only.
I guess a manual update of the keys would do the trick?
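
A minimal sketch of that manual key update, assuming the fine-tuned or distilled state dict already uses openai-whisper parameter names (decoder.blocks..., encoder.blocks...); the file names and the choice of tiny.en as the source of dims are illustrative assumptions:

from dataclasses import asdict

import torch
import whisper

# Borrow the model dimensions from the official model the fine-tune is based on
# (assumed here to be tiny.en).
base = whisper.load_model("tiny.en")
dims = asdict(base.dims)

# Load the fine-tuned/distilled weights (assumed to be a bare state dict).
# If the keys use a different naming scheme, they must be remapped first.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# Save in the {'dims', 'model_state_dict'} layout expected by the export script.
torch.save({"dims": dims, "model_state_dict": state_dict},
           "my-finetuned-tiny.en.pt")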

@csukuangfj (Collaborator, Author):

Please see

wget -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin

You need to find out how original-model.bin is generated.

@ziggy1209:

> Please see
>
> wget -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
>
> You need to find out how original-model.bin is generated.

That model loads fine with whisper.load_model(), but any student model generated by create_student_model.py has a key mismatch with it.

@csukuangfj (Collaborator, Author):

Here is the data contained in original-model.bin (the same holds for models from whisper):

(Pdb) p checkpoint.keys()
dict_keys(['dims', 'model_state_dict'])
(Pdb) p checkpoint['dims']
{'n_mels': 80, 'n_vocab': 51865, 'n_audio_ctx': 1500, 'n_audio_state': 384, 'n_audio_head': 6, 'n_audio_layer': 4, 'n_text_ctx': 448, 'n_text_state': 384, 'n_text_head': 6, 'n_text_layer': 4}

(Pdb) p checkpoint['model_state_dict'].keys()
dict_keys(['decoder.positional_embedding', 'encoder.positional_embedding', 'decoder.token_embedding.weight', 'decoder.blocks.0.mlp_ln.weight', 'decoder.blocks.0.mlp_ln.bias', 'decoder.blocks.0.mlp.0.weight', 'decoder.blocks.0.mlp.0.bias', 'decoder.blocks.0.mlp.2.weight', 'decoder.blocks.0.mlp.2.bias', 'decoder.blocks.0.attn_ln.weight', 'decoder.blocks.0.attn_ln.bias', 'decoder.blocks.0.attn.query.weight', 'decoder.blocks.0.attn.query.bias', 'decoder.blocks.0.attn.key.weight', 'decoder.blocks.0.attn.value.weight', 'decoder.blocks.0.attn.value.bias', 'decoder.blocks.0.attn.out.weight', 'decoder.blocks.0.attn.out.bias', 'decoder.blocks.0.cross_attn_ln.weight', 'decoder.blocks.0.cross_attn_ln.bias', 'decoder.blocks.0.cross_attn.query.weight', 'decoder.blocks.0.cross_attn.query.bias', 'decoder.blocks.0.cross_attn.key.weight', 'decoder.blocks.0.cross_attn.value.weight', 'decoder.blocks.0.cross_attn.value.bias', 'decoder.blocks.0.cross_attn.out.weight', 'decoder.blocks.0.cross_attn.out.bias', 'decoder.blocks.1.mlp_ln.weight', 'decoder.blocks.1.mlp_ln.bias', 'decoder.blocks.1.mlp.0.weight', 'decoder.blocks.1.mlp.0.bias', 'decoder.blocks.1.mlp.2.weight', 'decoder.blocks.1.mlp.2.bias', 'decoder.blocks.1.attn_ln.weight', 'decoder.blocks.1.attn_ln.bias', 'decoder.blocks.1.attn.query.weight', 'decoder.blocks.1.attn.query.bias', 'decoder.blocks.1.attn.key.weight', 'decoder.blocks.1.attn.value.weight', 'decoder.blocks.1.attn.value.bias', 'decoder.blocks.1.attn.out.weight', 'decoder.blocks.1.attn.out.bias', 'decoder.blocks.1.cross_attn_ln.weight', 'decoder.blocks.1.cross_attn_ln.bias', 'decoder.blocks.1.cross_attn.query.weight', 'decoder.blocks.1.cross_attn.query.bias', 'decoder.blocks.1.cross_attn.key.weight', 'decoder.blocks.1.cross_attn.value.weight', 'decoder.blocks.1.cross_attn.value.bias', 'decoder.blocks.1.cross_attn.out.weight', 'decoder.blocks.1.cross_attn.out.bias', 'decoder.blocks.2.mlp_ln.weight', 'decoder.blocks.2.mlp_ln.bias', 'decoder.blocks.2.mlp.0.weight', 'decoder.blocks.2.mlp.0.bias', 'decoder.blocks.2.mlp.2.weight', 'decoder.blocks.2.mlp.2.bias', 'decoder.blocks.2.attn_ln.weight', 'decoder.blocks.2.attn_ln.bias', 'decoder.blocks.2.attn.query.weight', 'decoder.blocks.2.attn.query.bias', 'decoder.blocks.2.attn.key.weight', 'decoder.blocks.2.attn.value.weight', 'decoder.blocks.2.attn.value.bias', 'decoder.blocks.2.attn.out.weight', 'decoder.blocks.2.attn.out.bias', 'decoder.blocks.2.cross_attn_ln.weight', 'decoder.blocks.2.cross_attn_ln.bias', 'decoder.blocks.2.cross_attn.query.weight', 'decoder.blocks.2.cross_attn.query.bias', 'decoder.blocks.2.cross_attn.key.weight', 'decoder.blocks.2.cross_attn.value.weight', 'decoder.blocks.2.cross_attn.value.bias', 'decoder.blocks.2.cross_attn.out.weight', 'decoder.blocks.2.cross_attn.out.bias', 'decoder.blocks.3.mlp_ln.weight', 'decoder.blocks.3.mlp_ln.bias', 'decoder.blocks.3.mlp.0.weight', 'decoder.blocks.3.mlp.0.bias', 'decoder.blocks.3.mlp.2.weight', 'decoder.blocks.3.mlp.2.bias', 'decoder.blocks.3.attn_ln.weight', 'decoder.blocks.3.attn_ln.bias', 'decoder.blocks.3.attn.query.weight', 'decoder.blocks.3.attn.query.bias', 'decoder.blocks.3.attn.key.weight', 'decoder.blocks.3.attn.value.weight', 'decoder.blocks.3.attn.value.bias', 'decoder.blocks.3.attn.out.weight', 'decoder.blocks.3.attn.out.bias', 'decoder.blocks.3.cross_attn_ln.weight', 'decoder.blocks.3.cross_attn_ln.bias', 'decoder.blocks.3.cross_attn.query.weight', 'decoder.blocks.3.cross_attn.query.bias', 'decoder.blocks.3.cross_attn.key.weight', 'decoder.blocks.3.cross_attn.value.weight', 
'decoder.blocks.3.cross_attn.value.bias', 'decoder.blocks.3.cross_attn.out.weight', 'decoder.blocks.3.cross_attn.out.bias', 'decoder.ln.weight', 'decoder.ln.bias', 'encoder.conv1.weight', 'encoder.conv1.bias', 'encoder.conv2.weight', 'encoder.conv2.bias', 'encoder.blocks.0.mlp_ln.weight', 'encoder.blocks.0.mlp_ln.bias', 'encoder.blocks.0.mlp.0.weight', 'encoder.blocks.0.mlp.0.bias', 'encoder.blocks.0.mlp.2.weight', 'encoder.blocks.0.mlp.2.bias', 'encoder.blocks.0.attn_ln.weight', 'encoder.blocks.0.attn_ln.bias', 'encoder.blocks.0.attn.query.weight', 'encoder.blocks.0.attn.query.bias', 'encoder.blocks.0.attn.key.weight', 'encoder.blocks.0.attn.value.weight', 'encoder.blocks.0.attn.value.bias', 'encoder.blocks.0.attn.out.weight', 'encoder.blocks.0.attn.out.bias', 'encoder.blocks.1.mlp_ln.weight', 'encoder.blocks.1.mlp_ln.bias', 'encoder.blocks.1.mlp.0.weight', 'encoder.blocks.1.mlp.0.bias', 'encoder.blocks.1.mlp.2.weight', 'encoder.blocks.1.mlp.2.bias', 'encoder.blocks.1.attn_ln.weight', 'encoder.blocks.1.attn_ln.bias', 'encoder.blocks.1.attn.query.weight', 'encoder.blocks.1.attn.query.bias', 'encoder.blocks.1.attn.key.weight', 'encoder.blocks.1.attn.value.weight', 'encoder.blocks.1.attn.value.bias', 'encoder.blocks.1.attn.out.weight', 'encoder.blocks.1.attn.out.bias', 'encoder.blocks.2.mlp_ln.weight', 'encoder.blocks.2.mlp_ln.bias', 'encoder.blocks.2.mlp.0.weight', 'encoder.blocks.2.mlp.0.bias', 'encoder.blocks.2.mlp.2.weight', 'encoder.blocks.2.mlp.2.bias', 'encoder.blocks.2.attn_ln.weight', 'encoder.blocks.2.attn_ln.bias', 'encoder.blocks.2.attn.query.weight', 'encoder.blocks.2.attn.query.bias', 'encoder.blocks.2.attn.key.weight', 'encoder.blocks.2.attn.value.weight', 'encoder.blocks.2.attn.value.bias', 'encoder.blocks.2.attn.out.weight', 'encoder.blocks.2.attn.out.bias', 'encoder.blocks.3.mlp_ln.weight', 'encoder.blocks.3.mlp_ln.bias', 'encoder.blocks.3.mlp.0.weight', 'encoder.blocks.3.mlp.0.bias', 'encoder.blocks.3.mlp.2.weight', 'encoder.blocks.3.mlp.2.bias', 'encoder.blocks.3.attn_ln.weight', 'encoder.blocks.3.attn_ln.bias', 'encoder.blocks.3.attn.query.weight', 'encoder.blocks.3.attn.query.bias', 'encoder.blocks.3.attn.key.weight', 'encoder.blocks.3.attn.value.weight', 'encoder.blocks.3.attn.value.bias', 'encoder.blocks.3.attn.out.weight', 'encoder.blocks.3.attn.out.bias', 'encoder.ln_post.weight', 'encoder.ln_post.bias'])
(Pdb)

Please make sure your saved model has a structure like the above dict.
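
As a quick sanity check (the file name is an assumption), a checkpoint with that structure should load via whisper.load_model() before being passed to the export script:

import whisper

# If this succeeds, the checkpoint has the {'dims', 'model_state_dict'} layout
# the export script expects.
model = whisper.load_model("my-finetuned-tiny.en.pt")
print(model.dims)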
