From 773243992ed81144e9d3e119715ec2b482116b45 Mon Sep 17 00:00:00 2001 From: "di.wu" Date: Fri, 25 Aug 2023 17:44:15 +0800 Subject: [PATCH 1/6] [tools] refine ctc alignment --- tools/alignment.sh | 43 +++++++++++++++++----------------- wenet/bin/alignment.py | 53 +++++++++++++++++++++++++++++++++--------- 2 files changed, 64 insertions(+), 32 deletions(-) diff --git a/tools/alignment.sh b/tools/alignment.sh index 64d860bb6..3f5ac4314 100644 --- a/tools/alignment.sh +++ b/tools/alignment.sh @@ -7,42 +7,43 @@ stage=0 # start from 0 if you need to start from data preparation stop_stage=0 nj=16 -feat_dir=raw_wav dict=data/dict/lang_char.txt dir=exp/ config=$dir/train.yaml -checkpoint= -checkpoint=/home/diwu/github/latest/wenet/examples/aishell/s0/exp/transformer/avg_20.pt -config=/home/diwu/github/latest/wenet/examples/aishell/s0/exp/transformer/train.yaml -set= -ali_format=$feat_dir/$set/format.data -ali_format=format.data -ali_result=$dir/ali - +# model trained with trim tail will get a better alignment result +checkpoint=$dir/final.pt + +set=test +ali_format=ali_format.data +ali_result=ali.res +blank_thres=0.9999 +thres=0.00001 . tools/parse_options.sh || exit 1; if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then - nj=32 # Prepare required data for ctc alignment echo "Prepare data, prepare required format" for x in $set; do - tools/format_data.sh --nj ${nj} \ - --feat-type wav --feat $feat_dir/$x/wav.scp \ - $feat_dir/$x ${dict} > $feat_dir/$x/format.data.tmp - + tools/make_raw_list.py data/$x/wav.scp data/$x/text \ + ali_format done fi + if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then # Test model, please specify the model you want to use by --checkpoint - python wenet/bin/alignment_deprecated.py --gpu -1 \ - --config $config \ - --input_file $ali_format \ - --checkpoint $checkpoint \ - --batch_size 1 \ - --dict $dict \ - --result_file $ali_result \ + mkdir -p exp_${thres} + python wenet/bin/alignment.py --gpu -1 \ + --config $config \ + --input_file $ali_format \ + --checkpoint $checkpoint \ + --batch_size 1 \ + --dict $dict \ + --result_file $ali_result \ + --thres $thres \ + --blank_thres $blank_thres \ + --gen_praat fi diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py index 071691183..44e13b6d1 100644 --- a/wenet/bin/alignment.py +++ b/wenet/bin/alignment.py @@ -25,6 +25,7 @@ import yaml from torch.utils.data import DataLoader from textgrid import TextGrid, IntervalTier +import math from wenet.dataset.dataset import Dataset from wenet.utils.checkpoint import load_checkpoint @@ -52,13 +53,17 @@ def generator_textgrid(maxtime, lines, output): tg.write(output) -def get_frames_timestamp(alignment): +def get_frames_timestamp(alignment, + prob, + blank_thres=0.999, + thres=0.0000000001): # convert alignment to a praat format, which is a doing phonetics # by computer and helps analyzing alignment timestamp = [] # get frames level duration for each token start = 0 end = 0 + local_start = 0 while end < len(alignment): while end < len(alignment) and alignment[end] == 0: end += 1 @@ -68,26 +73,44 @@ def get_frames_timestamp(alignment): end += 1 while end < len(alignment) and alignment[end - 1] == alignment[end]: end += 1 - timestamp.append(alignment[start:end]) + local_start = end - 1 + # find the possible front border for current token + while local_start >= start and prob[local_start][0] < math.log( + blank_thres) or prob[local_start][alignment[ + end - 1]] > math.log(thres): + alignment[local_start] = alignment[end - 1] + local_start -= 1 + cur_alignment = alignment[start:end] + timestamp.append(cur_alignment) start = end return timestamp def get_labformat(timestamp, subsample): begin = 0 + begin_time = 0 duration = 0 labformat = [] for idx, t in enumerate(timestamp): # 25ms frame_length,10ms hop_length, 1/subsample subsample = get_subsample(configs) # time duration - duration = len(t) * 0.01 * subsample + i = 0 + while t[i] == 0: + i += 1 + begin = i + dur = 0 + while i < len(t) and t[i] != 0: + i += 1 + dur += 1 + begin = begin_time + begin * 0.01 * subsample + duration = dur * 0.01 * subsample if idx < len(timestamp) - 1: print("{:.2f} {:.2f} {}".format(begin, begin + duration, char_dict[t[-1]])) labformat.append("{:.2f} {:.2f} {}\n".format( begin, begin + duration, char_dict[t[-1]])) - else: + else: # last token non_blank = 0 for i in t: if i != 0: @@ -97,7 +120,7 @@ def get_labformat(timestamp, subsample): char_dict[token])) labformat.append("{:.2f} {:.2f} {}\n".format( begin, begin + duration, char_dict[token])) - begin = begin + duration + begin_time += len(t) * 0.01 * subsample return labformat @@ -114,10 +137,19 @@ def get_labformat(timestamp, subsample): type=int, default=-1, help='gpu id for this rank, -1 for cpu') + parser.add_argument('--blank_thres', + default=0.999999, + type=float, + help='ctc blank thes') + parser.add_argument('--thres', + default=0.000001, + type=float, + help='ctc non blank thes') parser.add_argument('--checkpoint', required=True, help='checkpoint model') parser.add_argument('--dict', required=True, help='dict file') - parser.add_argument('--non_lang_syms', - help="non-linguistic symbol file. One symbol per line.") + parser.add_argument( + '--non_lang_syms', + help="non-linguistic symbol file. One symbol per line.") parser.add_argument('--result_file', required=True, help='alignment result file') @@ -165,6 +197,7 @@ def get_labformat(timestamp, subsample): ali_conf['filter_conf']['min_output_input_ratio'] = 0 ali_conf['speed_perturb'] = False ali_conf['spec_aug'] = False + ali_conf['spec_trim'] = False ali_conf['shuffle'] = False ali_conf['sort'] = False ali_conf['fbank_conf']['dither'] = 0.0 @@ -196,7 +229,6 @@ def get_labformat(timestamp, subsample): for batch_idx, batch in enumerate(ali_data_loader): print("#" * 80) key, feat, target, feats_length, target_length = batch - print(key) feat = feat.to(device) target = target.to(device) @@ -213,12 +245,11 @@ def get_labformat(timestamp, subsample): ctc_probs = ctc_probs.squeeze(0) target = target.squeeze(0) alignment = forced_align(ctc_probs, target) - print(alignment) fout.write('{} {}\n'.format(key[0], alignment)) if args.gen_praat: - timestamp = get_frames_timestamp(alignment) - print(timestamp) + timestamp = get_frames_timestamp(alignment, ctc_probs, + args.blank_thres, args.thres) subsample = get_subsample(configs) labformat = get_labformat(timestamp, subsample) From ee02522df8559defaf9a194fab851e35b74d2f43 Mon Sep 17 00:00:00 2001 From: Di Wu <1176705630@qq.com> Date: Fri, 25 Aug 2023 17:51:23 +0800 Subject: [PATCH 2/6] Update alignment.py --- wenet/bin/alignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py index 44e13b6d1..9b9c25451 100644 --- a/wenet/bin/alignment.py +++ b/wenet/bin/alignment.py @@ -110,7 +110,7 @@ def get_labformat(timestamp, subsample): char_dict[t[-1]])) labformat.append("{:.2f} {:.2f} {}\n".format( begin, begin + duration, char_dict[t[-1]])) - else: # last token + else: # last token non_blank = 0 for i in t: if i != 0: From be2f00fca9f9a443933dd597b106ab037d0646e6 Mon Sep 17 00:00:00 2001 From: Di Wu <1176705630@qq.com> Date: Fri, 25 Aug 2023 17:52:17 +0800 Subject: [PATCH 3/6] Update alignment.sh --- tools/alignment.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/alignment.sh b/tools/alignment.sh index 3f5ac4314..636387c1e 100644 --- a/tools/alignment.sh +++ b/tools/alignment.sh @@ -43,7 +43,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then --result_file $ali_result \ --thres $thres \ --blank_thres $blank_thres \ - --gen_praat + --gen_praat fi From 7815761bda096247e45fd0137adb8dcf5c21f12a Mon Sep 17 00:00:00 2001 From: Di Wu <1176705630@qq.com> Date: Fri, 25 Aug 2023 17:54:07 +0800 Subject: [PATCH 4/6] Update alignment.py --- wenet/bin/alignment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py index 9b9c25451..ab40000a3 100644 --- a/wenet/bin/alignment.py +++ b/wenet/bin/alignment.py @@ -110,7 +110,7 @@ def get_labformat(timestamp, subsample): char_dict[t[-1]])) labformat.append("{:.2f} {:.2f} {}\n".format( begin, begin + duration, char_dict[t[-1]])) - else: # last token + else: # last token non_blank = 0 for i in t: if i != 0: From 46ebe6e268be1c07fcf13f47cbc3970d3102c66b Mon Sep 17 00:00:00 2001 From: Di Wu <1176705630@qq.com> Date: Fri, 25 Aug 2023 17:57:10 +0800 Subject: [PATCH 5/6] add ToDo --- tools/alignment.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/alignment.sh b/tools/alignment.sh index 636387c1e..a32f0650f 100644 --- a/tools/alignment.sh +++ b/tools/alignment.sh @@ -12,6 +12,7 @@ dict=data/dict/lang_char.txt dir=exp/ config=$dir/train.yaml # model trained with trim tail will get a better alignment result +# (Todo) cif/attention/rnnt alignment checkpoint=$dir/final.pt set=test From 2b92d0a7a13cd50db9de3ee9b3fd326bcbb3a90c Mon Sep 17 00:00:00 2001 From: Di Wu <1176705630@qq.com> Date: Mon, 28 Aug 2023 16:38:35 +0800 Subject: [PATCH 6/6] bug fix --- wenet/bin/alignment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wenet/bin/alignment.py b/wenet/bin/alignment.py index ab40000a3..c37eac4b6 100644 --- a/wenet/bin/alignment.py +++ b/wenet/bin/alignment.py @@ -75,9 +75,9 @@ def get_frames_timestamp(alignment, end += 1 local_start = end - 1 # find the possible front border for current token - while local_start >= start and prob[local_start][0] < math.log( + while local_start >= start and (prob[local_start][0] < math.log( blank_thres) or prob[local_start][alignment[ - end - 1]] > math.log(thres): + end - 1]] > math.log(thres)): alignment[local_start] = alignment[end - 1] local_start -= 1 cur_alignment = alignment[start:end]