Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tools] refine ctc alignment #1966

Merged
merged 6 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions tools/alignment.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,44 @@ stage=0 # start from 0 if you need to start from data preparation
stop_stage=0

nj=16
feat_dir=raw_wav
dict=data/dict/lang_char.txt

dir=exp/
config=$dir/train.yaml
checkpoint=
checkpoint=/home/diwu/github/latest/wenet/examples/aishell/s0/exp/transformer/avg_20.pt
config=/home/diwu/github/latest/wenet/examples/aishell/s0/exp/transformer/train.yaml
set=
ali_format=$feat_dir/$set/format.data
ali_format=format.data
ali_result=$dir/ali

# model trained with trim tail will get a better alignment result
# (Todo) cif/attention/rnnt alignment
checkpoint=$dir/final.pt

set=test
ali_format=ali_format.data
ali_result=ali.res
blank_thres=0.9999
thres=0.00001
. tools/parse_options.sh || exit 1;

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
nj=32
# Prepare required data for ctc alignment
echo "Prepare data, prepare required format"
for x in $set; do
tools/format_data.sh --nj ${nj} \
--feat-type wav --feat $feat_dir/$x/wav.scp \
$feat_dir/$x ${dict} > $feat_dir/$x/format.data.tmp

tools/make_raw_list.py data/$x/wav.scp data/$x/text \
ali_format
done
fi


if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
# Test model, please specify the model you want to use by --checkpoint
python wenet/bin/alignment_deprecated.py --gpu -1 \
--config $config \
--input_file $ali_format \
--checkpoint $checkpoint \
--batch_size 1 \
--dict $dict \
--result_file $ali_result \
mkdir -p exp_${thres}
python wenet/bin/alignment.py --gpu -1 \
--config $config \
--input_file $ali_format \
--checkpoint $checkpoint \
--batch_size 1 \
--dict $dict \
--result_file $ali_result \
--thres $thres \
--blank_thres $blank_thres \
--gen_praat

fi

Expand Down
53 changes: 42 additions & 11 deletions wenet/bin/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import yaml
from torch.utils.data import DataLoader
from textgrid import TextGrid, IntervalTier
import math

from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import load_checkpoint
Expand Down Expand Up @@ -52,13 +53,17 @@ def generator_textgrid(maxtime, lines, output):
tg.write(output)


def get_frames_timestamp(alignment):
def get_frames_timestamp(alignment,
prob,
blank_thres=0.999,
thres=0.0000000001):
# convert alignment to a praat format, which is a doing phonetics
# by computer and helps analyzing alignment
timestamp = []
# get frames level duration for each token
start = 0
end = 0
local_start = 0
while end < len(alignment):
while end < len(alignment) and alignment[end] == 0:
end += 1
Expand All @@ -68,26 +73,44 @@ def get_frames_timestamp(alignment):
end += 1
while end < len(alignment) and alignment[end - 1] == alignment[end]:
end += 1
timestamp.append(alignment[start:end])
local_start = end - 1
# find the possible front border for current token
while local_start >= start and (prob[local_start][0] < math.log(
blank_thres) or prob[local_start][alignment[
end - 1]] > math.log(thres)):
alignment[local_start] = alignment[end - 1]
local_start -= 1
cur_alignment = alignment[start:end]
timestamp.append(cur_alignment)
start = end
return timestamp


def get_labformat(timestamp, subsample):
begin = 0
begin_time = 0
duration = 0
labformat = []
for idx, t in enumerate(timestamp):
# 25ms frame_length,10ms hop_length, 1/subsample
subsample = get_subsample(configs)
# time duration
duration = len(t) * 0.01 * subsample
i = 0
while t[i] == 0:
i += 1
begin = i
dur = 0
while i < len(t) and t[i] != 0:
i += 1
dur += 1
begin = begin_time + begin * 0.01 * subsample
duration = dur * 0.01 * subsample
if idx < len(timestamp) - 1:
print("{:.2f} {:.2f} {}".format(begin, begin + duration,
char_dict[t[-1]]))
labformat.append("{:.2f} {:.2f} {}\n".format(
begin, begin + duration, char_dict[t[-1]]))
else:
else: # last token
non_blank = 0
for i in t:
if i != 0:
Expand All @@ -97,7 +120,7 @@ def get_labformat(timestamp, subsample):
char_dict[token]))
labformat.append("{:.2f} {:.2f} {}\n".format(
begin, begin + duration, char_dict[token]))
begin = begin + duration
begin_time += len(t) * 0.01 * subsample
return labformat


Expand All @@ -114,10 +137,19 @@ def get_labformat(timestamp, subsample):
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--blank_thres',
default=0.999999,
type=float,
help='ctc blank thes')
parser.add_argument('--thres',
default=0.000001,
type=float,
help='ctc non blank thes')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--dict', required=True, help='dict file')
parser.add_argument('--non_lang_syms',
help="non-linguistic symbol file. One symbol per line.")
parser.add_argument(
'--non_lang_syms',
help="non-linguistic symbol file. One symbol per line.")
parser.add_argument('--result_file',
required=True,
help='alignment result file')
Expand Down Expand Up @@ -165,6 +197,7 @@ def get_labformat(timestamp, subsample):
ali_conf['filter_conf']['min_output_input_ratio'] = 0
ali_conf['speed_perturb'] = False
ali_conf['spec_aug'] = False
ali_conf['spec_trim'] = False
ali_conf['shuffle'] = False
ali_conf['sort'] = False
ali_conf['fbank_conf']['dither'] = 0.0
Expand Down Expand Up @@ -196,7 +229,6 @@ def get_labformat(timestamp, subsample):
for batch_idx, batch in enumerate(ali_data_loader):
print("#" * 80)
key, feat, target, feats_length, target_length = batch
print(key)

feat = feat.to(device)
target = target.to(device)
Expand All @@ -213,12 +245,11 @@ def get_labformat(timestamp, subsample):
ctc_probs = ctc_probs.squeeze(0)
target = target.squeeze(0)
alignment = forced_align(ctc_probs, target)
print(alignment)
fout.write('{} {}\n'.format(key[0], alignment))

if args.gen_praat:
timestamp = get_frames_timestamp(alignment)
print(timestamp)
timestamp = get_frames_timestamp(alignment, ctc_probs,
args.blank_thres, args.thres)
subsample = get_subsample(configs)
labformat = get_labformat(timestamp, subsample)

Expand Down