Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jan 10, 2024
1 parent 4fd372f commit 67aa3d8
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 3 deletions.
1 change: 1 addition & 0 deletions scripts/3dspeaker/export-onnx.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang)

import argparse
import json
Expand Down
32 changes: 32 additions & 0 deletions scripts/3dspeaker/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,20 @@ function install_3d_speaker() {
popd
}

function download_test_data() {
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_cn_16k.wav
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_cn_16k.wav
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_cn_16k.wav

wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_a_en_16k.wav
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker1_b_en_16k.wav
wget -q https://github.com/csukuangfj/sr-data/raw/main/test/3d-speaker/speaker2_a_en_16k.wav
}

install_3d_speaker

download_test_data

export PYTHONPATH=$PWD/3D-Speaker:$PYTHONPATH
export PYTHONPATH=$PWD/3D-Speaker/speakerlab/bin:$PYTHONPATH

Expand All @@ -28,4 +40,24 @@ speech_eres2net_large_sv_zh-cn_3dspeaker_16k
for model in ${models[@]}; do
echo "--------------------$model--------------------"
python3 ./export-onnx.py --model $model

python3 ./test-onnx.py \
--model ${model}.onnx \
--file1 ./speaker1_a_cn_16k.wav \
--file2 ./speaker1_b_cn_16k.wav

python3 ./test-onnx.py \
--model ${model}.onnx \
--file1 ./speaker1_a_cn_16k.wav \
--file2 ./speaker2_a_cn_16k.wav

python3 ./test-onnx.py \
--model ${model}.onnx \
--file1 ./speaker1_a_en_16k.wav \
--file2 ./speaker1_b_en_16k.wav

python3 ./test-onnx.py \
--model ${model}.onnx \
--file1 ./speaker1_a_en_16k.wav \
--file2 ./speaker2_a_en_16k.wav
done
173 changes: 173 additions & 0 deletions scripts/3dspeaker/test-onnx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
#!/usr/bin/env python3
# Copyright 2023-2024 Xiaomi Corp. (authors: Fangjun Kuang)

"""
This script computes speaker similarity score in the range [0-1]
of two wave files using a speaker embedding model.
"""
import argparse
import wave
from pathlib import Path

import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
from numpy.linalg import norm


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
required=True,
help="Path to the input onnx model. Example value: model.onnx",
)

parser.add_argument(
"--file1",
type=str,
required=True,
help="Input wave 1",
)

parser.add_argument(
"--file2",
type=str,
required=True,
help="Input wave 2",
)

return parser.parse_args()


def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
"""
Args:
filename:
Path to a wave file, which must be of 16-bit and 16kHz.
expected_sample_rate:
Expected sample rate of the wave file.
Returns:
Return a 1-D float32 array containing audio samples. Each sample is in
the range [-1, 1].
"""
filename = str(filename)
with wave.open(filename) as f:
wave_file_sample_rate = f.getframerate()
assert wave_file_sample_rate == expected_sample_rate, (
wave_file_sample_rate,
expected_sample_rate,
)

num_channels = f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
samples_float32 = samples_int16.astype(np.float32)

samples_float32 = samples_float32 / 32768

return samples_float32


def compute_features(samples: np.ndarray, sample_rate: int) -> np.ndarray:
opts = knf.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.samp_freq = sample_rate
opts.frame_opts.snip_edges = True

opts.mel_opts.num_bins = 80
opts.mel_opts.debug_mel = False

fbank = knf.OnlineFbank(opts)
fbank.accept_waveform(sample_rate, samples)
fbank.input_finished()

features = []
for i in range(fbank.num_frames_ready):
f = fbank.get_frame(i)
features.append(f)
features = np.stack(features, axis=0)

return features


class OnnxModel:
def __init__(
self,
filename: str,
):
session_opts = ort.SessionOptions()
session_opts.inter_op_num_threads = 1
session_opts.intra_op_num_threads = 1

self.session_opts = session_opts

self.model = ort.InferenceSession(
filename,
sess_options=self.session_opts,
)

meta = self.model.get_modelmeta().custom_metadata_map
self.normalize_samples = int(meta["normalize_samples"])
self.sample_rate = int(meta["sample_rate"])
self.output_dim = int(meta["output_dim"])
self.feature_normalize_type = meta["feature_normalize_type"]

def __call__(self, x: np.ndarray) -> np.ndarray:
"""
Args:
x:
A 2-D float32 tensor of shape (T, C).
y:
A 1-D float32 tensor containing model output.
"""
x = np.expand_dims(x, axis=0)

return self.model.run(
[
self.model.get_outputs()[0].name,
],
{
self.model.get_inputs()[0].name: x,
},
)[0][0]


def main():
args = get_args()
print(args)
filename = Path(args.model)
file1 = Path(args.file1)
file2 = Path(args.file2)
assert filename.is_file(), filename
assert file1.is_file(), file1
assert file2.is_file(), file2

model = OnnxModel(filename)
wave1 = read_wavefile(file1, model.sample_rate)
wave2 = read_wavefile(file2, model.sample_rate)

if not model.normalize_samples:
wave1 = wave1 * 32768
wave2 = wave2 * 32768

features1 = compute_features(wave1, model.sample_rate)
features2 = compute_features(wave2, model.sample_rate)

if model.feature_normalize_type == "global-mean":
features1 -= features1.mean(axis=0, keepdims=True)
features2 -= features2.mean(axis=0, keepdims=True)

output1 = model(features1)
output2 = model(features2)

similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
print(f"similarity in the range [0-1]: {similarity}")


if __name__ == "__main__":
main()
4 changes: 1 addition & 3 deletions scripts/wespeaker/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""
This script computes speaker similarity score in the range [0-1]
of two wave files using a speaker recognition model.
of two wave files using a speaker embedding model.
"""
import argparse
import wave
Expand Down Expand Up @@ -159,8 +159,6 @@ def main():
output1 = model(features1)
output2 = model(features2)

print(output1.shape)
print(output2.shape)
similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
print(f"similarity in the range [0-1]: {similarity}")

Expand Down

0 comments on commit 67aa3d8

Please sign in to comment.