From b83a68ac4f02c764479a28d291fb21b26543ff12 Mon Sep 17 00:00:00 2001
From: "xiang.lv"
Date: Wed, 26 Oct 2022 07:19:41 -0400
Subject: [PATCH] add hlg decode

---
 examples/aishell/s0/README.md  |   4 +-
 examples/aishell/s0/run.sh     |  46 ++++++
 tools/k2/make_hlg.sh           |  36 +++++
 tools/k2/prepare_char.py       | 258 +++++++++++++++++++++++++++++++++
 wenet/bin/recognize.py         |  47 +++++-
 wenet/transformer/asr_model.py | 170 ++++++++++++++++++++++
 6 files changed, 559 insertions(+), 2 deletions(-)
 create mode 100755 tools/k2/make_hlg.sh
 create mode 100644 tools/k2/prepare_char.py

diff --git a/examples/aishell/s0/README.md b/examples/aishell/s0/README.md
index 96675cbf4..436da8168 100644
--- a/examples/aishell/s0/README.md
+++ b/examples/aishell/s0/README.md
@@ -19,7 +19,7 @@
 
 * Feature info: using fbank feature, dither=1.0, cmvn, online speed perturb
 * Training info: lr 0.001, batch size 16, 8 gpu, acc_grad 1, 360 epochs
-* Decoding info: ctc_weight 0.3, reverse_weight 0.5 average_num 30
+* Decoding info: ctc_weight 0.3, reverse_weight 0.5, average_num 30, lm_scale 0.7, decoder_scale 0.1, r_decoder_scale 0.7
 * Git hash: 5a1342312668e7a5abb83aed1e53256819cebf95
 
 | decoding mode/chunk size | full | 16 |
@@ -28,6 +28,8 @@
 | ctc prefix beam search           | 5.17 | 5.81 |
 | attention rescoring              | 4.63 | 5.05 |
 | LM + attention rescoring         | 4.40 | 4.75 |
+| HLG(k2 LM)                       | 4.81 | 5.27 |
+| HLG(k2 LM) + attention rescoring | 4.32 | 4.70 |
 
 ## Unified Conformer Result

diff --git a/examples/aishell/s0/run.sh b/examples/aishell/s0/run.sh
index e18b77c84..e136b1216 100644
--- a/examples/aishell/s0/run.sh
+++ b/examples/aishell/s0/run.sh
@@ -239,4 +239,50 @@ if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
   # Please see $dir/lm_with_runtime for wer
 fi
 
+# Optionally, you can decode with the k2 HLG graph
+if [ ${stage} -le 8 ] && [ ${stop_stage} -ge 8 ]; then
+  if [ ! -f data/local/lm/lm.arpa ]; then
+    echo "Please run stage 7 first to prepare the dict and train the LM"
+    exit 1
+  fi
+
+  # 8.1 Build decoding HLG
+  required="data/local/hlg/HLG.pt data/local/hlg/words.txt"
+  for f in $required; do
+    if [ ! -f $f ]; then
+      tools/k2/make_hlg.sh data/local/dict/ data/local/lm/ data/local/hlg
+      break
+    fi
+  done
+
+  # 8.2 Decode using HLG
+  decoding_chunk_size=
+  lm_scale=0.7
+  decoder_scale=0.1
+  r_decoder_scale=0.7
+  for mode in hlg_onebest hlg_rescore; do
+  {
+    test_dir=$dir/test_${mode}
+    mkdir -p $test_dir
+    python wenet/bin/recognize.py --gpu 0 \
+      --mode $mode \
+      --config $dir/train.yaml \
+      --data_type $data_type \
+      --test_data data/test/data.list \
+      --checkpoint $decode_checkpoint \
+      --beam_size 10 \
+      --batch_size 16 \
+      --penalty 0.0 \
+      --dict $dict \
+      --word data/local/hlg/words.txt \
+      --hlg data/local/hlg/HLG.pt \
+      --lm_scale $lm_scale \
+      --decoder_scale $decoder_scale \
+      --r_decoder_scale $r_decoder_scale \
+      --result_file $test_dir/text \
+      ${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
+    python tools/compute-wer.py --char=1 --v=1 \
+      data/test/text $test_dir/text > $test_dir/wer
+  }
+  done
+fi
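In the stage above, `${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}` expands to the flag only when the variable is non-empty, so leaving it blank selects full-context decoding. The same decode can also be driven from Python without going through recognize.py; the sketch below is illustrative only (the `model`, `feats` and `feats_lengths` objects and the dict path are assumptions, not part of this patch):

    import torch

    # token -> id table, the same file passed to recognize.py via --dict
    symbol_table = {}
    with open('data/dict/lang_char.txt', encoding='utf8') as f:  # assumed path
        for line in f:
            char, idx = line.split()
            symbol_table[char] = int(idx)

    # model: an ASRModel with the new hlg_onebest() method, weights loaded
    with torch.no_grad():
        hyps = model.hlg_onebest(feats, feats_lengths,
                                 hlg='data/local/hlg/HLG.pt',
                                 word='data/local/hlg/words.txt',
                                 symbol_table=symbol_table)

    # hlg_onebest returns token ids per utterance; map back to characters
    char_dict = {v: k for k, v in symbol_table.items()}
    print(''.join(char_dict[t] for t in hyps[0]))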
diff --git a/tools/k2/make_hlg.sh b/tools/k2/make_hlg.sh
new file mode 100755
index 000000000..73f0b6a16
--- /dev/null
+++ b/tools/k2/make_hlg.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+# Copyright 2022 Ximalaya Speech Team (author: Xiang Lyu)
+
+lexicon_dir=$1
+lm_dir=$2
+tgt_dir=$3
+
+# For k2 installation, please refer to https://github.com/k2-fsa/k2/
+python -c "import k2; print(k2.__file__)"
+python -c "import torch; import _k2; print(_k2.__file__)"
+
+# Prepare the necessary icefall scripts
+if [ ! -d tools/k2/icefall ]; then
+  git clone --depth 1 https://github.com/k2-fsa/icefall.git tools/k2/icefall
+fi
+pip install -r tools/k2/icefall/requirements.txt
+export PYTHONPATH=`pwd`/tools/k2/icefall:`pwd`/tools/k2/icefall/egs/aishell/ASR/local:$PYTHONPATH
+
+# 1. Prepare the char-based lang dir
+mkdir -p $tgt_dir
+python tools/k2/prepare_char.py $lexicon_dir/units.txt $lm_dir/wordlist $tgt_dir
+echo "Compiling the lexicon (L.pt, L_disambig.pt) succeeded"
+
+# 2. Prepare G
+mkdir -p data/lm
+python -m kaldilm \
+    --read-symbol-table="$tgt_dir/words.txt" \
+    --disambig-symbol='#0' \
+    --max-order=3 \
+    $lm_dir/lm.arpa > data/lm/G_3_gram.fst.txt
+
+# 3. Compile HLG
+python tools/k2/icefall/egs/aishell/ASR/local/compile_hlg.py --lang-dir $tgt_dir
+echo "Compiling the decoding graph HLG.pt succeeded"
\ No newline at end of file
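The graph built here is the composition H ∘ L ∘ G: H is the CTC topology over the units in units.txt, L is the character lexicon compiled by prepare_char.py, and G is the 3-gram LM converted from lm.arpa by kaldilm. A quick sanity check of the resulting artifact, sketched under the assumption that k2 is installed and make_hlg.sh has already run:

    import torch
    import k2

    # HLG.pt stores the graph as a dict of tensors; from_dict rebuilds the FSA
    hlg = k2.Fsa.from_dict(torch.load('data/local/hlg/HLG.pt',
                                      map_location='cpu'))
    print(hlg.shape, hlg.num_arcs)  # a single FSA has shape (num_states, None)
    # asr_model.py clones scores into lm_scores on first load if it is absent
    print(hasattr(hlg, 'lm_scores'))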
diff --git a/tools/k2/prepare_char.py b/tools/k2/prepare_char.py
new file mode 100644
index 000000000..6e05042c4
--- /dev/null
+++ b/tools/k2/prepare_char.py
@@ -0,0 +1,258 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+# Copyright 2022 Ximalaya Speech Team (author: Xiang Lyu)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This script generates the following files in the directory sys.argv[3]:
+
+    - lexicon.txt
+    - lexicon_disambig.txt
+    - L.pt
+    - L_disambig.pt
+    - tokens.txt
+    - words.txt
+"""
+
+import sys
+from pathlib import Path
+from typing import Dict, List
+
+import k2
+import torch
+from prepare_lang import (
+    Lexicon,
+    add_disambig_symbols,
+    add_self_loops,
+    write_lexicon,
+    write_mapping,
+)
+
+
+def lexicon_to_fst_no_sil(
+    lexicon: Lexicon,
+    token2id: Dict[str, int],
+    word2id: Dict[str, int],
+    need_self_loops: bool = False,
+) -> k2.Fsa:
+    """Convert a lexicon to an FST (in k2 format).
+
+    Args:
+      lexicon:
+        The input lexicon. See also :func:`read_lexicon`.
+      token2id:
+        A dict mapping tokens to IDs.
+      word2id:
+        A dict mapping words to IDs.
+      need_self_loops:
+        If True, add a self-loop to states with non-epsilon output symbols
+        on at least one arc out of the state. The input label for this
+        self-loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
+    Returns:
+      Return an instance of `k2.Fsa` representing the given lexicon.
+    """
+    loop_state = 0  # words enter and leave from here
+    next_state = 1  # the next un-allocated state, will be incremented as we go
+
+    arcs = []
+
+    # <blank> must be the first token and <eps> the first word (both ID 0)
+    assert token2id["<blank>"] == 0
+    assert word2id["<eps>"] == 0
+
+    eps = 0
+
+    for word, pieces in lexicon:
+        assert len(pieces) > 0, f"{word} has no pronunciations"
+        cur_state = loop_state
+
+        word = word2id[word]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]
+
+        for i in range(len(pieces) - 1):
+            w = word if i == 0 else eps
+            arcs.append([cur_state, next_state, pieces[i], w, 0])
+
+            cur_state = next_state
+            next_state += 1
+
+        # now for the last piece of this word
+        i = len(pieces) - 1
+        w = word if i == 0 else eps
+        arcs.append([cur_state, loop_state, pieces[i], w, 0])
+
+    if need_self_loops:
+        disambig_token = token2id["#0"]
+        disambig_word = word2id["#0"]
+        arcs = add_self_loops(
+            arcs,
+            disambig_token=disambig_token,
+            disambig_word=disambig_word,
+        )
+
+    final_state = next_state
+    arcs.append([loop_state, final_state, -1, -1, 0])
+    arcs.append([final_state])
+
+    arcs = sorted(arcs, key=lambda arc: arc[0])
+    arcs = [[str(i) for i in arc] for arc in arcs]
+    arcs = [" ".join(arc) for arc in arcs]
+    arcs = "\n".join(arcs)
+
+    fsa = k2.Fsa.from_str(arcs, acceptor=False)
+    return fsa
+
+
+def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
+    """Check whether the given tokens contain any symbol that is outside
+    the token symbol table.
+
+    Args:
+      token_sym_table:
+        Token symbol table that contains all the valid tokens.
+      tokens:
+        A list of tokens.
+    Returns:
+      Return True if there is any token not in the token_sym_table,
+      otherwise False.
+    """
+    for tok in tokens:
+        if tok not in token_sym_table:
+            return True
+    return False
+
+
+def generate_lexicon(
+    token_sym_table: Dict[str, int], words: List[str]
+) -> Lexicon:
+    """Generate a lexicon from a word list and a token symbol table.
+
+    Args:
+      token_sym_table:
+        Token symbol table that maps tokens to token IDs.
+      words:
+        A list of strings representing words.
+    Returns:
+      Return a lexicon: a list of (word, tokens) pairs, where the tokens
+      are the characters of the word.
+    """
+    lexicon = []
+    for word in words:
+        chars = list(word.strip(" \t"))
+        if contain_oov(token_sym_table, chars):
+            continue
+        lexicon.append((word, chars))
+
+    # The OOV word is <UNK>
+    lexicon.append(("<UNK>", ["<unk>"]))
+    return lexicon
+
+
+def generate_tokens(text_file: str) -> Dict[str, int]:
+    """Generate tokens from the given text file.
+
+    Args:
+      text_file:
+        A file that contains text lines to generate tokens.
+    Returns:
+      Return a dict whose keys are tokens and values are token IDs ranging
+      from 0 to len(keys) - 1.
+    """
+    token2id: Dict[str, int] = dict()
+    with open(text_file, "r", encoding="utf-8") as f:
+        for line in f:
+            char, index = line.replace('\n', '').split()
+            assert char not in token2id
+            token2id[char] = int(index)
+    assert token2id['<blank>'] == 0
+    return token2id
+
+
+def generate_words(text_file: str) -> Dict[str, int]:
+    """Generate words from the given text file.
+
+    Args:
+      text_file:
+        A file that contains text lines to generate words.
+    Returns:
+      Return a dict whose keys are words and values are word IDs ranging
+      from 0 to len(keys) - 1.
+    """
+    words = []
+    with open(text_file, "r", encoding="utf-8") as f:
+        for line in f:
+            word = line.replace('\n', '')
+            assert word not in words
+            words.append(word)
+    words.sort()
+
+    # We put '<eps>' and '<UNK>' at the beginning of word2id,
+    # and '#0', '<s>', '</s>' at the end of word2id
+    words = [word for word in words
+             if word not in ['<eps>', '<UNK>', '#0', '<s>', '</s>']]
+    words.insert(0, '<eps>')
+    words.insert(1, '<UNK>')
+    words.append('#0')
+    words.append('<s>')
+    words.append('</s>')
+    word2id = {j: i for i, j in enumerate(words)}
+    return word2id
+
+
+def main():
+    token2id = generate_tokens(sys.argv[1])
+    word2id = generate_words(sys.argv[2])
+    tgt_dir = Path(sys.argv[3])
+
+    words = [word for word in word2id.keys()
+             if word not in
+             ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"]]
+    lexicon = generate_lexicon(token2id, words)
+
+    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)
+    next_token_id = max(token2id.values()) + 1
+    for i in range(max_disambig + 1):
+        disambig = f"#{i}"
+        assert disambig not in token2id
+        token2id[disambig] = next_token_id
+        next_token_id += 1
+
+    write_mapping(tgt_dir / "tokens.txt", token2id)
+    write_mapping(tgt_dir / "words.txt", word2id)
+    write_lexicon(tgt_dir / "lexicon.txt", lexicon)
+    write_lexicon(tgt_dir / "lexicon_disambig.txt", lexicon_disambig)
+
+    L = lexicon_to_fst_no_sil(
+        lexicon,
+        token2id=token2id,
+        word2id=word2id,
+    )
+    L_disambig = lexicon_to_fst_no_sil(
+        lexicon_disambig,
+        token2id=token2id,
+        word2id=word2id,
+        need_self_loops=True,
+    )
+    torch.save(L.as_dict(), tgt_dir / "L.pt")
+    torch.save(L_disambig.as_dict(), tgt_dir / "L_disambig.pt")
+
+
+if __name__ == "__main__":
+    main()
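To make the lexicon construction concrete, here is a toy trace through generate_lexicon() and lexicon_to_fst_no_sil() with hypothetical symbols (the real inputs are units.txt and the LM wordlist):

    # Toy symbol tables in the shape prepare_char.py builds:
    token2id = {'<blank>': 0, '<unk>': 1, '中': 2, '国': 3}
    word2id = {'<eps>': 0, '<UNK>': 1, '中国': 2, '中': 3}

    # generate_lexicon keeps words whose characters are all known and
    # appends the OOV entry:
    lexicon = [('中国', ['中', '国']), ('中', ['中']), ('<UNK>', ['<unk>'])]

    # lexicon_to_fst_no_sil turns each entry into a path that leaves and
    # re-enters loop state 0; the word label is emitted on the first arc,
    # epsilon (0) on the rest:
    #   0 -> 1   in '中'(2)     out '中国'(2)
    #   1 -> 0   in '国'(3)     out <eps>(0)
    #   0 -> 0   in '中'(2)     out '中'(3)
    #   0 -> 0   in '<unk>'(1)  out '<UNK>'(1)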
diff --git a/wenet/bin/recognize.py b/wenet/bin/recognize.py
index c8b939741..03b5dfd42 100644
--- a/wenet/bin/recognize.py
+++ b/wenet/bin/recognize.py
@@ -64,7 +64,8 @@ def get_args():
                             'attention', 'ctc_greedy_search',
                             'ctc_prefix_beam_search', 'attention_rescoring',
                             'rnnt_greedy_search', 'rnnt_beam_search',
-                            'rnnt_beam_attn_rescoring', 'ctc_beam_td_attn_rescoring'
+                            'rnnt_beam_attn_rescoring', 'ctc_beam_td_attn_rescoring',
+                            'hlg_onebest', 'hlg_rescore'
                         ],
                         default='attention',
                         help='decoding mode')
@@ -127,6 +128,27 @@ def get_args():
                         type=str,
                         help='used to connect the output characters')
 
+    parser.add_argument('--word',
+                        default='',
+                        type=str,
+                        help='word file, only used for hlg decode')
+    parser.add_argument('--hlg',
+                        default='',
+                        type=str,
+                        help='hlg file, only used for hlg decode')
+    parser.add_argument('--lm_scale',
+                        type=float,
+                        default=0.0,
+                        help='n-gram lm scale for hlg attention rescore decode')
+    parser.add_argument('--decoder_scale',
+                        type=float,
+                        default=0.0,
+                        help='left-to-right decoder scale for hlg attention rescore decode')
+    parser.add_argument('--r_decoder_scale',
+                        type=float,
+                        default=0.0,
+                        help='right-to-left decoder scale for hlg attention rescore decode')
+
     args = parser.parse_args()
     print(args)
     return args
@@ -298,6 +320,29 @@ def main():
                     simulate_streaming=args.simulate_streaming,
                     reverse_weight=args.reverse_weight)
                 hyps = [hyp]
+            elif args.mode == 'hlg_onebest':
+                hyps = model.hlg_onebest(
+                    feats,
+                    feats_lengths,
+                    decoding_chunk_size=args.decoding_chunk_size,
+                    num_decoding_left_chunks=args.num_decoding_left_chunks,
+                    simulate_streaming=args.simulate_streaming,
+                    hlg=args.hlg,
+                    word=args.word,
+                    symbol_table=symbol_table)
+            elif args.mode == 'hlg_rescore':
+                hyps = model.hlg_rescore(
+                    feats,
+                    feats_lengths,
+                    decoding_chunk_size=args.decoding_chunk_size,
+                    num_decoding_left_chunks=args.num_decoding_left_chunks,
+                    simulate_streaming=args.simulate_streaming,
+                    lm_scale=args.lm_scale,
+                    decoder_scale=args.decoder_scale,
+                    r_decoder_scale=args.r_decoder_scale,
+                    hlg=args.hlg,
+                    word=args.word,
+                    symbol_table=symbol_table)
             for i, key in enumerate(keys):
                 content = []
                 for w in hyps[i]:
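The three scale flags added above only take effect in hlg_rescore mode; hlg_onebest ignores them. Schematically, the per-path score that hlg_rescore (added to asr_model.py below) maximizes is:

    # One value per n-best path; the names mirror the implementation below.
    # tot_scores = am_scores                            (CTC acoustic score)
    #            + lm_scale        * ngram_lm_scores    (HLG n-gram LM)
    #            + decoder_scale   * decoder_scores     (left-to-right decoder)
    #            + r_decoder_scale * r_decoder_scores   (right-to-left decoder)
    # The path with the highest tot_scores within each utterance wins.
    # The README numbers use lm_scale 0.7, decoder_scale 0.1,
    # r_decoder_scale 0.7.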
diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index 5dd151e49..367c9189a 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -20,6 +20,14 @@
 from torch.nn.utils.rnn import pad_sequence
 
+try:
+    import k2
+    from icefall.utils import get_texts
+    from icefall.decode import get_lattice, Nbest, one_best_decoding
+except ImportError:
+    print('Failed to import k2 and icefall. '
+          'Note that they are required for hlg_onebest and hlg_rescore.')
+
 from wenet.transformer.ctc import CTC
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.encoder import TransformerEncoder
@@ -538,6 +546,168 @@ def attention_rescoring(
                 best_index = i
         return hyps[best_index][0], best_score
 
+    def load_hlg_resource_if_necessary(self, hlg, word):
+        if not hasattr(self, 'hlg'):
+            device = torch.device(
+                'cuda' if torch.cuda.is_available() else 'cpu')
+            self.hlg = k2.Fsa.from_dict(torch.load(hlg, map_location=device))
+            if not hasattr(self.hlg, 'lm_scores'):
+                self.hlg.lm_scores = self.hlg.scores.clone()
+        if not hasattr(self, 'word_table'):
+            self.word_table = {}
+            with open(word, 'r') as fin:
+                for line in fin:
+                    arr = line.strip().split()
+                    assert len(arr) == 2
+                    self.word_table[int(arr[1])] = arr[0]
+
+    @torch.no_grad()
+    def hlg_onebest(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        decoding_chunk_size: int = -1,
+        num_decoding_left_chunks: int = -1,
+        simulate_streaming: bool = False,
+        hlg: str = '',
+        word: str = '',
+        symbol_table: Dict[str, int] = None,
+    ) -> List[List[int]]:
+        self.load_hlg_resource_if_necessary(hlg, word)
+        encoder_out, encoder_mask = self._forward_encoder(
+            speech, speech_lengths, decoding_chunk_size,
+            num_decoding_left_chunks,
+            simulate_streaming)  # (B, maxlen, encoder_dim)
+        ctc_probs = self.ctc.log_softmax(
+            encoder_out)  # (B, maxlen, vocab_size)
+        # one row per utterance: (utterance index, start frame, num frames)
+        supervision_segments = torch.stack(
+            (torch.arange(len(encoder_mask)),
+             torch.zeros(len(encoder_mask)),
+             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
+            1).to(torch.int32)
+        lattice = get_lattice(
+            nnet_output=ctc_probs,
+            decoding_graph=self.hlg,
+            supervision_segments=supervision_segments,
+            search_beam=20,
+            output_beam=7,
+            min_active_states=30,
+            max_active_states=10000,
+            subsampling_factor=4)
+        best_path = one_best_decoding(lattice=lattice,
+                                      use_double_scores=True)
+        hyps = get_texts(best_path)
+        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
+                for i in hyps]
+        return hyps
+
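+    # hlg_rescore below builds the same CTC + HLG lattice as hlg_onebest,
+    # but instead of taking the single best path it extracts an n-best
+    # list (num_paths=100) and re-ranks those paths with the attention
+    # decoder in both directions, weighted by lm_scale, decoder_scale
+    # and r_decoder_scale.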
+    @torch.no_grad()
+    def hlg_rescore(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        decoding_chunk_size: int = -1,
+        num_decoding_left_chunks: int = -1,
+        simulate_streaming: bool = False,
+        lm_scale: float = 0,
+        decoder_scale: float = 0,
+        r_decoder_scale: float = 0,
+        hlg: str = '',
+        word: str = '',
+        symbol_table: Dict[str, int] = None,
+    ) -> List[List[int]]:
+        self.load_hlg_resource_if_necessary(hlg, word)
+        device = speech.device
+        encoder_out, encoder_mask = self._forward_encoder(
+            speech, speech_lengths, decoding_chunk_size,
+            num_decoding_left_chunks,
+            simulate_streaming)  # (B, maxlen, encoder_dim)
+        ctc_probs = self.ctc.log_softmax(
+            encoder_out)  # (B, maxlen, vocab_size)
+        supervision_segments = torch.stack(
+            (torch.arange(len(encoder_mask)),
+             torch.zeros(len(encoder_mask)),
+             encoder_mask.squeeze(dim=1).sum(dim=1).cpu()),
+            1).to(torch.int32)
+        lattice = get_lattice(
+            nnet_output=ctc_probs,
+            decoding_graph=self.hlg,
+            supervision_segments=supervision_segments,
+            search_beam=20,
+            output_beam=7,
+            min_active_states=30,
+            max_active_states=10000,
+            subsampling_factor=4)
+        nbest = Nbest.from_lattice(
+            lattice=lattice,
+            num_paths=100,
+            use_double_scores=True,
+            nbest_scale=0.5)
+        nbest = nbest.intersect(lattice)
+        assert hasattr(nbest.fsa, 'lm_scores')
+        assert hasattr(nbest.fsa, 'tokens')
+        assert isinstance(nbest.fsa.tokens, torch.Tensor)
+
+        tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
+        tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
+        tokens = tokens.remove_values_leq(0)
+        hyps = tokens.tolist()
+
+        # calculate attention scores
+        hyps_pad = pad_sequence([
+            torch.tensor(hyp, device=device, dtype=torch.long)
+            for hyp in hyps
+        ], True, self.ignore_id)  # (beam_size, max_hyps_len)
+        ori_hyps_pad = hyps_pad
+        hyps_lens = torch.tensor([len(hyp) for hyp in hyps],
+                                 device=device,
+                                 dtype=torch.long)  # (beam_size,)
+        hyps_pad, _ = add_sos_eos(hyps_pad, self.sos, self.eos,
+                                  self.ignore_id)
+        hyps_lens = hyps_lens + 1  # Add <sos> at the beginning
+        encoder_out_repeat = []
+        tot_scores = nbest.tot_scores()
+        repeats = [tot_scores[i].shape[0] for i in range(tot_scores.dim0)]
+        for i in range(len(encoder_out)):
+            encoder_out_repeat.append(encoder_out[i:i + 1].repeat(
+                repeats[i], 1, 1))
+        encoder_out = torch.concat(encoder_out_repeat, dim=0)
+        encoder_mask = torch.ones(encoder_out.size(0),
+                                  1,
+                                  encoder_out.size(1),
+                                  dtype=torch.bool,
+                                  device=device)
+        # used for right to left decoder
+        r_hyps_pad = reverse_pad_list(ori_hyps_pad, hyps_lens, self.ignore_id)
+        r_hyps_pad, _ = add_sos_eos(r_hyps_pad, self.sos, self.eos,
+                                    self.ignore_id)
+        reverse_weight = 0.5
+        decoder_out, r_decoder_out, _ = self.decoder(
+            encoder_out, encoder_mask, hyps_pad, hyps_lens, r_hyps_pad,
+            reverse_weight)  # (beam_size, max_hyps_len, vocab_size)
+        decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
+        # r_decoder_out will be 0.0, if reverse_weight is 0.0 or decoder is a
+        # conventional transformer decoder.
+        r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
+
+        decoder_scores = torch.tensor(
+            [sum([decoder_out[i, j, hyps[i][j]] for j in range(len(hyps[i]))])
+             for i in range(len(hyps))], device=device)
+        r_decoder_scores = []
+        for i in range(len(hyps)):
+            score = 0
+            for j in range(len(hyps[i])):
+                score += r_decoder_out[i, len(hyps[i]) - j - 1, hyps[i][j]]
+            score += r_decoder_out[i, len(hyps[i]), self.eos]
+            r_decoder_scores.append(score)
+        r_decoder_scores = torch.tensor(r_decoder_scores, device=device)
+
+        am_scores = nbest.compute_am_scores()
+        ngram_lm_scores = nbest.compute_lm_scores()
+        tot_scores = am_scores.values + lm_scale * ngram_lm_scores.values + \
+            decoder_scale * decoder_scores + r_decoder_scale * r_decoder_scores
+        ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
+        max_indexes = ragged_tot_scores.argmax()
+        best_path = k2.index_fsa(nbest.fsa, max_indexes)
+        hyps = get_texts(best_path)
+        hyps = [[symbol_table[k] for j in i for k in self.word_table[j]]
+                for i in hyps]
+        return hyps
+
     @torch.jit.export
     def subsampling_rate(self) -> int:
         """ Export interface for c++ call, return subsampling_rate of the