Skip to content

Commit

Permalink
[paraformer] separate cif and paraformer (#1795)
Browse files Browse the repository at this point in the history
  • Loading branch information
robin1001 authored Mar 30, 2023
1 parent b4086db commit 4d6c144
Show file tree
Hide file tree
Showing 16 changed files with 40 additions and 131 deletions.
88 changes: 0 additions & 88 deletions examples/aishell/cif/conf/train_cif_conformer.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Performance Record

## Standard CIF Conformer Result(TODO)


## Paraformer like CIF Conformer Result
## Paraformer, Conformer Result

* Feature info: using fbank feature, dither, cmvn, online speed perturb
* Training info: lr 0.002, batch size 16, 4 gpu, acc_grad 4, 240 epochs, dither 0.1
Expand All @@ -16,7 +14,7 @@
| cif greedy search | 4.41 | 4.92 |
| cif beam search | 4.35 | 4.86 |

## Conformer CIF DecoderSANM Result(Deprecated)
## Paraformer, Conformer DecoderSANM Result (Deprecated)

* Feature info: using fbank feature, dither, cmvn, online speed perturb
* Training info: lr 0.002, batch size 16, 4 gpu, acc_grad 4, 240 epochs, dither 0.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
input_layer: 'none'
src_attention: true

cif_predictor: predictor_v1
paraformer: true
cif_predictor_conf:
idim: 256
threshold: 1.0
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@ num_utts_per_shard=1000

train_set=train
# Optional train_config
# conf/train_cif_conformer.yaml: standard non streaming CIF,
# conf/train_paraformer.yaml: paraformer like CIF, apply encoder-decoder attention
train_config=conf/train_cif_conformer.yaml
train_config=conf/train_paraformer.yaml
cmvn=true
dir=exp/cif_conformer
dir=exp/paraformer
checkpoint=

# using average_checkpoint will give a better result
Expand All @@ -52,8 +51,8 @@ average_num=20
#decode_modes="ctc_greedy_search ctc_prefix_beam_search cif_greedy_search cif_beam_search"
# Since the predictor loss also plays an important role in the training of the CIF models,
# and the CTC-related decoding methods cannot take advantage of the trained predictor,
# so we strongly recommend that you use the 'cif_greedy_search' and 'cif_beam_search'.
decode_modes="cif_greedy_search cif_beam_search"
# so we strongly recommend that you use the 'paraformer_greedy_search' and 'paraformer_beam_search'.
decode_modes="paraformer_greedy_search paraformer_beam_search"

. tools/parse_options.sh || exit 1;

Expand Down
File renamed without changes.
File renamed without changes.
25 changes: 12 additions & 13 deletions wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.paraformer.search.beam_search import build_beam_search
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model

from wenet.cif.search.beam_search import build_beam_search


def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
Expand Down Expand Up @@ -69,8 +68,8 @@ def get_args():
'rnnt_greedy_search', 'rnnt_beam_search',
'rnnt_beam_attn_rescoring',
'ctc_beam_td_attn_rescoring', 'hlg_onebest',
'hlg_rescore', 'cif_greedy_search',
'cif_beam_search',
'hlg_rescore', 'paraformer_greedy_search',
'paraformer_beam_search',
],
default='attention',
help='decoding mode')
Expand Down Expand Up @@ -166,7 +165,7 @@ def main():
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring',
'cif_beam_search', ] and args.batch_size > 1:
'paraformer_beam_search', ] and args.batch_size > 1:
logging.fatal(
'decoding mode {} must be running with batch_size == 1'.format(
args.mode))
Expand Down Expand Up @@ -225,10 +224,10 @@ def main():
model.eval()

# Build BeamSearchCIF object
if args.mode == 'cif_beam_search':
cif_beam_search = build_beam_search(model, args, device)
if args.mode == 'paraformer_beam_search':
paraformer_beam_search = build_beam_search(model, args, device)
else:
cif_beam_search = None
paraformer_beam_search = None

with torch.no_grad(), open(args.result_file, 'w') as fout:
for batch_idx, batch in enumerate(test_data_loader):
Expand Down Expand Up @@ -355,16 +354,16 @@ def main():
hlg=args.hlg,
word=args.word,
symbol_table=symbol_table)
elif args.mode == 'cif_beam_search':
hyps = model.cif_beam_search(
elif args.mode == 'paraformer_beam_search':
hyps = model.paraformer_beam_search(
feats,
feats_lengths,
beam_search=cif_beam_search,
beam_search=paraformer_beam_search,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
elif args.mode == 'cif_greedy_search':
hyps = model.cif_greedy_search(
elif args.mode == 'paraformer_greedy_search':
hyps = model.paraformer_greedy_search(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
Expand Down
13 changes: 7 additions & 6 deletions wenet/cif/cif_model.py → wenet/paraformer/paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch

from wenet.cif.predictor import MAELoss
from wenet.cif.search.beam_search import Hypothesis
from wenet.paraformer.search.beam_search import Hypothesis
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
Expand All @@ -29,9 +29,10 @@
from wenet.utils.mask import make_pad_mask


class CIFModel(ASRModel):
""" Continuous Integrate-and-Fire model
see https://arxiv.org/pdf/1905.11235.pdf
class Paraformer(ASRModel):
""" Paraformer: Fast and Accurate Parallel Transformer for
Non-autoregressive End-to-End Speech Recognition
see https://arxiv.org/pdf/2206.08317.pdf
"""

def __init__(
Expand Down Expand Up @@ -169,7 +170,7 @@ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens,
def recognize(self):
raise NotImplementedError

def cif_greedy_search(
def paraformer_greedy_search(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
Expand Down Expand Up @@ -249,7 +250,7 @@ def cif_greedy_search(
hyps.append(token_int)
return hyps

def cif_beam_search(
def paraformer_beam_search(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

import torch

from wenet.cif.utils import end_detect
from wenet.cif.search.ctc import CTCPrefixScorer
from wenet.cif.search.scorer_interface import ScorerInterface, \
from wenet.paraformer.utils import end_detect
from wenet.paraformer.search.ctc import CTCPrefixScorer
from wenet.paraformer.search.scorer_interface import ScorerInterface, \
PartialScorerInterface


Expand Down
5 changes: 3 additions & 2 deletions wenet/cif/search/ctc.py → wenet/paraformer/search/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import numpy as np
import torch

from wenet.cif.search.ctc_prefix_score import CTCPrefixScore, CTCPrefixScoreTH
from wenet.cif.search.scorer_interface import BatchPartialScorerInterface
from wenet.paraformer.search.ctc_prefix_score import CTCPrefixScore
from wenet.paraformer.search.ctc_prefix_score import CTCPrefixScoreTH
from wenet.paraformer.search.scorer_interface import BatchPartialScorerInterface


class CTCPrefixScorer(BatchPartialScorerInterface):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
16 changes: 8 additions & 8 deletions wenet/utils/init_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.cif.cif_model import CIFModel
from wenet.paraformer.paraformer import Paraformer
from wenet.cif.predictor import Predictor
from wenet.utils.cmvn import load_cmvn

Expand Down Expand Up @@ -104,14 +104,14 @@ def init_model(configs):
joint=joint,
ctc=ctc,
**configs['model_conf'])
elif 'cif_predictor' in configs:
elif 'paraformer' in configs:
predictor = Predictor(**configs['cif_predictor_conf'])
model = CIFModel(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
model = Paraformer(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
else:
model = ASRModel(vocab_size=vocab_size,
encoder=encoder,
Expand Down

0 comments on commit 4d6c144

Please sign in to comment.