Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[paraformer] separate cif and paraformer #1795

Merged
merged 1 commit into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 0 additions & 88 deletions examples/aishell/cif/conf/train_cif_conformer.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Performance Record

## Standard CIF Conformer Result(TODO)


## Paraformer like CIF Conformer Result
## Paraformer, Conformer Result

* Feature info: using fbank feature, dither, cmvn, online speed perturb
* Training info: lr 0.002, batch size 16, 4 gpu, acc_grad 4, 240 epochs, dither 0.1
Expand All @@ -16,7 +14,7 @@
| cif greedy search | 4.41 | 4.92 |
| cif beam search | 4.35 | 4.86 |

## Conformer CIF DecoderSANM Result(Deprecated)
## Paraformer, Conformer DecoderSANM Result(Deprecated)

* Feature info: using fbank feature, dither, cmvn, online speed perturb
* Training info: lr 0.002, batch size 16, 4 gpu, acc_grad 4, 240 epochs, dither 0.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ decoder_conf:
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
input_layer: 'none'
src_attention: true

cif_predictor: predictor_v1
paraformer: true
cif_predictor_conf:
idim: 256
threshold: 1.0
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@ num_utts_per_shard=1000

train_set=train
# Optional train_config
# conf/train_cif_conformer.yaml: standard non streaming CIF,
# conf/train_paraformer.yaml: paraformer like CIF, apply encoder-decoder attention
train_config=conf/train_cif_conformer.yaml
train_config=conf/train_paraformer.yaml
cmvn=true
dir=exp/cif_conformer
dir=exp/paraformer
checkpoint=

# using average_checkpoint will give better results
Expand All @@ -52,8 +51,8 @@ average_num=20
#decode_modes="ctc_greedy_search ctc_prefix_beam_search cif_greedy_search cif_beam_search"
# The predictor loss plays an important role in the training of CIF-based models, and
# the CTC-related decoding methods cannot take advantage of the trained predictor,
# so we strongly recommend that you use the 'cif_greedy_search' and 'cif_beam_search'.
decode_modes="cif_greedy_search cif_beam_search"
# so we strongly recommend that you use the 'paraformer_greedy_search' and 'paraformer_beam_search'.
decode_modes="paraformer_greedy_search paraformer_beam_search"

. tools/parse_options.sh || exit 1;

Expand Down
File renamed without changes.
File renamed without changes.
25 changes: 12 additions & 13 deletions wenet/bin/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,12 @@
from torch.utils.data import DataLoader

from wenet.dataset.dataset import Dataset
from wenet.paraformer.search.beam_search import build_beam_search
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols
from wenet.utils.config import override_config
from wenet.utils.init_model import init_model

from wenet.cif.search.beam_search import build_beam_search


def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
Expand Down Expand Up @@ -69,8 +68,8 @@ def get_args():
'rnnt_greedy_search', 'rnnt_beam_search',
'rnnt_beam_attn_rescoring',
'ctc_beam_td_attn_rescoring', 'hlg_onebest',
'hlg_rescore', 'cif_greedy_search',
'cif_beam_search',
'hlg_rescore', 'paraformer_greedy_search',
'paraformer_beam_search',
],
default='attention',
help='decoding mode')
Expand Down Expand Up @@ -166,7 +165,7 @@ def main():
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring',
'cif_beam_search', ] and args.batch_size > 1:
'paraformer_beam_search', ] and args.batch_size > 1:
logging.fatal(
'decoding mode {} must be running with batch_size == 1'.format(
args.mode))
Expand Down Expand Up @@ -225,10 +224,10 @@ def main():
model.eval()

# Build BeamSearchCIF object
if args.mode == 'cif_beam_search':
cif_beam_search = build_beam_search(model, args, device)
if args.mode == 'paraformer_beam_search':
paraformer_beam_search = build_beam_search(model, args, device)
else:
cif_beam_search = None
paraformer_beam_search = None

with torch.no_grad(), open(args.result_file, 'w') as fout:
for batch_idx, batch in enumerate(test_data_loader):
Expand Down Expand Up @@ -355,16 +354,16 @@ def main():
hlg=args.hlg,
word=args.word,
symbol_table=symbol_table)
elif args.mode == 'cif_beam_search':
hyps = model.cif_beam_search(
elif args.mode == 'paraformer_beam_search':
hyps = model.paraformer_beam_search(
feats,
feats_lengths,
beam_search=cif_beam_search,
beam_search=paraformer_beam_search,
decoding_chunk_size=args.decoding_chunk_size,
num_decoding_left_chunks=args.num_decoding_left_chunks,
simulate_streaming=args.simulate_streaming)
elif args.mode == 'cif_greedy_search':
hyps = model.cif_greedy_search(
elif args.mode == 'paraformer_greedy_search':
hyps = model.paraformer_greedy_search(
feats,
feats_lengths,
decoding_chunk_size=args.decoding_chunk_size,
Expand Down
13 changes: 7 additions & 6 deletions wenet/cif/cif_model.py → wenet/paraformer/paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch

from wenet.cif.predictor import MAELoss
from wenet.cif.search.beam_search import Hypothesis
from wenet.paraformer.search.beam_search import Hypothesis
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
Expand All @@ -29,9 +29,10 @@
from wenet.utils.mask import make_pad_mask


class CIFModel(ASRModel):
""" Continuous Integrate-and-Fire model
see https://arxiv.org/pdf/1905.11235.pdf
class Paraformer(ASRModel):
""" Paraformer: Fast and Accurate Parallel Transformer for
Non-autoregressive End-to-End Speech Recognition
see https://arxiv.org/pdf/2206.08317.pdf
"""

def __init__(
Expand Down Expand Up @@ -169,7 +170,7 @@ def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens,
def recognize(self):
raise NotImplementedError

def cif_greedy_search(
def paraformer_greedy_search(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
Expand Down Expand Up @@ -249,7 +250,7 @@ def cif_greedy_search(
hyps.append(token_int)
return hyps

def cif_beam_search(
def paraformer_beam_search(
self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

import torch

from wenet.cif.utils import end_detect
from wenet.cif.search.ctc import CTCPrefixScorer
from wenet.cif.search.scorer_interface import ScorerInterface, \
from wenet.paraformer.utils import end_detect
from wenet.paraformer.search.ctc import CTCPrefixScorer
from wenet.paraformer.search.scorer_interface import ScorerInterface, \
PartialScorerInterface


Expand Down
5 changes: 3 additions & 2 deletions wenet/cif/search/ctc.py → wenet/paraformer/search/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import numpy as np
import torch

from wenet.cif.search.ctc_prefix_score import CTCPrefixScore, CTCPrefixScoreTH
from wenet.cif.search.scorer_interface import BatchPartialScorerInterface
from wenet.paraformer.search.ctc_prefix_score import CTCPrefixScore
from wenet.paraformer.search.ctc_prefix_score import CTCPrefixScoreTH
from wenet.paraformer.search.scorer_interface import BatchPartialScorerInterface


class CTCPrefixScorer(BatchPartialScorerInterface):
Expand Down
File renamed without changes.
16 changes: 8 additions & 8 deletions wenet/utils/init_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.cif.cif_model import CIFModel
from wenet.paraformer.paraformer import Paraformer
from wenet.cif.predictor import Predictor
from wenet.utils.cmvn import load_cmvn

Expand Down Expand Up @@ -104,14 +104,14 @@ def init_model(configs):
joint=joint,
ctc=ctc,
**configs['model_conf'])
elif 'cif_predictor' in configs:
elif 'paraformer' in configs:
predictor = Predictor(**configs['cif_predictor_conf'])
model = CIFModel(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
model = Paraformer(vocab_size=vocab_size,
encoder=encoder,
decoder=decoder,
ctc=ctc,
predictor=predictor,
**configs['model_conf'])
else:
model = ASRModel(vocab_size=vocab_size,
encoder=encoder,
Expand Down