Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support distil-small.en whisper #472

Merged
merged 7 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/scripts/test-offline-whisper.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ tiny
base
small
medium
distil-medium.en
distil-small.en
)

for name in ${names[@]}; do
Expand Down
59 changes: 44 additions & 15 deletions .github/workflows/export-whisper-to-onnx.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
model: ["distil-medium.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2"]
os: [macos-latest]
# model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large", "large-v1", "large-v2", "distil-large-v2"]
model: ["distil-medium.en", "distil-small.en", "tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium"]
python-version: ["3.8"]

steps:
Expand All @@ -42,45 +43,66 @@ jobs:
if [[ $model == distil-medium.en ]]; then
wget -q -O distil-medium-en-original-model.bin https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/original-model.bin
ls -lh
elif [[ $model == distil-large-v2 ]]; then
wget -q -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
ls -lh
elif [[ $model == distil-small.en ]]; then
wget -q -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
ls -lh
fi
python3 ./export-onnx.py --model ${{ matrix.model }}
# python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./

ls -lh

if [[ $model != distil-medium.en ]]; then
ls -lh ~/.cache/whisper
fi
ls -lh ~/.cache/whisper || true
ls -lh distil*original-model.bin || true
rm -rf ~/.cache/whisper
rm -f distil*original-model.bin

src=sherpa-onnx-whisper-${{ matrix.model }}

mkdir $src
cp *.onnx $src/
cp *tokens.txt $src
cd ..
mv whisper $src

echo "------------------------------"

cd $src
du -h -d1 .
ls -lh
mkdir -p test_wavs
cd test_wavs
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/0.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/1.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/8k.wav
wget -q https://huggingface.co/csukuangfj/sherpa-onnx-whisper-medium.en/resolve/main/test_wavs/trans.txt
cd ../..
mv $src ../..
mv $src ../
echo "pwd: $PWD"

cd ../..
cd ../
echo "--------------------"
ls -lh
ls -lh $src
echo "--------------------"

tar cjvf ./$src.tar.bz2 $src
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
#tar cvjf - $src | split --bytes=1024MB - $src.tar.bz2.
tar cvjf $src.tar.bz2 $src
split -b 1G $src.tar.bz2 $src.tar.bz2.
rm $src.tar.bz2
# To reassemble and extract the split archive: cat $src.tar.bz2.* | tar xjf -
else
tar cvjf $src.tar.bz2 $src
fi
ls -lh


- name: Release
uses: svenstaro/upload-release-action@v2
with:
file_glob: true
file: ./*.tar.bz2
file: ./*.tar*
overwrite: true
repo_name: k2-fsa/sherpa-onnx
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
Expand All @@ -99,14 +121,21 @@ jobs:
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
rm -rf huggingface/*

cp -av $src/* ./huggingface/
if [[ $model == large || $model == large-v1 || $model == large-v2 || $model == distil-large-v2 ]]; then
mv $src.tar* ./huggingface
else
cp -v $src/*.onnx ./huggingface
cp -v $src/*tokens* ./huggingface
cp -av $src/test_wavs ./huggingface
fi

cd huggingface

git status
ls -lh
git lfs track "*.onnx"
# git lfs track "*.ort"
git lfs track "*gz*"
git lfs track "*onnx*"

git add .
git commit -m "upload ${{ matrix.model }}"
git push https://csukuangfj:[email protected]/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
4 changes: 2 additions & 2 deletions .github/workflows/test-python-offline-websocket-server.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
./sherpa-onnx-zipformer-en-2023-06-26/test_wavs/8k.wav

- name: Start server for paraformer models
if: matrix.model_type == 'paraformer'
if: matrix.model_type == 'paraformer' && matrix.os != 'windows-latest'
shell: bash
run: |
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-bilingual-zh-en
Expand All @@ -106,7 +106,7 @@ jobs:
sleep 10

- name: Start client for paraformer models
if: matrix.model_type == 'paraformer'
if: matrix.model_type == 'paraformer' && matrix.os != 'windows-latest'
shell: bash
run: |
python3 ./python-api-examples/offline-websocket-client-decode-files-paralell.py \
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.9.0")
set(SHERPA_ONNX_VERSION "1.9.1")

# Disable warning about
#
Expand Down
28 changes: 27 additions & 1 deletion scripts/whisper/export-onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_args():
"tiny", "tiny.en", "base", "base.en",
"small", "small.en", "medium", "medium.en",
"large", "large-v1", "large-v2",
"distil-medium.en",
"distil-medium.en", "distil-small.en", "distil-large-v2"
],
# fmt: on
)
Expand Down Expand Up @@ -314,6 +314,32 @@ def main():
"""
)
model = whisper.load_model(filename)
elif name == "distil-large-v2":
filename = "./distil-large-v2-original-model.bin"
if not Path(filename).is_file():
raise ValueError(
"""
Please go to https://huggingface.co/distil-whisper/distil-large-v2
to download original-model.bin
You can use the following command to do that:

wget -O distil-large-v2-original-model.bin https://huggingface.co/distil-whisper/distil-large-v2/resolve/main/original-model.bin
"""
)
model = whisper.load_model(filename)
elif name == "distil-small.en":
filename = "./distil-small-en-original-model.bin"
if not Path(filename).is_file():
raise ValueError(
"""
Please go to https://huggingface.co/distil-whisper/distil-small.en
to download original-model.bin
You can use the following command to do that:

wget -O distil-small-en-original-model.bin https://huggingface.co/distil-whisper/distil-small.en/resolve/main/original-model.bin
"""
)
model = whisper.load_model(filename)
else:
model = whisper.load_model(name)
print(model.dims)
Expand Down
6 changes: 4 additions & 2 deletions scripts/whisper/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def detect_language(
logits = logits.reshape(-1)
mask = torch.ones(logits.shape[0], dtype=torch.int64)
mask[self.all_language_tokens] = 0
logits[mask] = float("-inf")
logits[mask != 0] = float("-inf")
lang_id = logits.argmax().item()
print("detected language: ", self.id2lang[lang_id])
return lang_id
Expand Down Expand Up @@ -263,7 +263,9 @@ def compute_features(filename: str) -> torch.Tensor:

target = 3000
if mel.shape[0] > target:
mel = mel[:target]
# Keep only the first (target - 50) frames so that 50 frames of zero tail padding can be appended.
mel = mel[: target - 50]
mel = torch.nn.functional.pad(mel, (0, 0, 0, 50), "constant", 0)

# We don't need to pad it to 30 seconds now!
# mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
Expand Down
7 changes: 4 additions & 3 deletions sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,12 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;

if (num_frames > max_num_frames) {
// Cap at max_num_frames - 50 so that some zero tail padding is left
if (num_frames >= max_num_frames - 50) {
SHERPA_ONNX_LOGE(
"Only waves less than 30 seconds are supported. We process only the "
"first 30 seconds and discard the remaining data");
num_frames = max_num_frames;
num_frames = max_num_frames - 50;
}

NormalizeFeatures(f.data(), num_frames, feat_dim);
Expand Down Expand Up @@ -140,7 +141,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
Ort::Value mel = Ort::Value::CreateTensor<float>(
model_->Allocator(), shape.data(), shape.size());
float *p_mel = mel.GetTensorMutableData<float>();
std::copy(f.begin(), f.end(), p_mel);
std::copy(f.data(), f.data() + actual_frames * feat_dim, p_mel);

memset(p_mel + f.size(), 0,
(actual_frames - num_frames) * feat_dim * sizeof(float));
Expand Down
Loading