Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[paraformer] timestamp #2277

Merged
merged 8 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/wenet/text/test_paraformer_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def paraformer_tokenizer(request):
_download_fn(download_root, seg_dict)

config_name = 'config.yaml'
_download_fn(download_root, config_name)
_download_fn(download_root, config_name, version='v1.2.4')
with open(os.path.join(download_root, config_name), 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
wenet_units = os.path.join(download_root, 'units.txt')
Expand Down
21 changes: 12 additions & 9 deletions wenet/cli/paraformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.paraformer.search import paraformer_beautify_result, paraformer_greedy_search
from wenet.paraformer.search import (gen_timestamps_from_peak,
paraformer_beautify_result,
paraformer_greedy_search)
from wenet.text.paraformer_tokenizer import ParaformerTokenizer


Expand Down Expand Up @@ -45,28 +47,29 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
dtype=torch.int64,
device=feats.device)

decoder_out, token_num = self.model.forward_paraformer(
decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
feats, feats_lens)

res = paraformer_greedy_search(decoder_out, token_num)[0]

cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
result = {}
result['confidence'] = res.confidence
result['text'] = paraformer_beautify_result(
self.tokenizer.detokenize(res.tokens)[1])
if tokens_info:
tokens_info = []
times = gen_timestamps_from_peak(res.times,
num_frames=tp_alphas.size(1),
frame_rate=0.02)

for i, x in enumerate(res.tokens):
tokens_info.append({
'token': self.tokenizer.char_dict[x],
# TODO(Mddct): support times
# 'start': 0,
# 'end': 0,
'start': times[i][0],
'end': times[i][1],
'confidence': res.tokens_confidence[i]
})
result['tokens'] = tokens_info

# result = ''.join(hyp)
return result

def align(self, audio_file: str, label: str) -> dict:
Expand Down
66 changes: 48 additions & 18 deletions wenet/paraformer/cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@

class Cif(nn.Module):

def __init__(self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0,
tail_threshold=0.45,
residual=True,
cnn_groups=0):
def __init__(
self,
idim,
l_order,
r_order,
threshold=1.0,
dropout=0.1,
smooth_factor=1.0,
noise_threshold=0.0,
tail_threshold=0.45,
residual=True,
cnn_groups=0,
):
super().__init__()

self.pad = nn.ConstantPad1d((l_order, r_order), 0.0)
Expand All @@ -50,13 +52,15 @@ def __init__(self,
self.tail_threshold = tail_threshold
self.residual = residual

def forward(self,
hidden,
target_label: Optional[torch.Tensor] = None,
mask: torch.Tensor = torch.tensor(0),
ignore_id: int = -1,
mask_chunk_predictor: Optional[torch.Tensor] = None,
target_label_length: Optional[torch.Tensor] = None):
def forward(
self,
hidden,
target_label: Optional[torch.Tensor] = None,
mask: torch.Tensor = torch.tensor(0),
ignore_id: int = -1,
mask_chunk_predictor: Optional[torch.Tensor] = None,
target_label_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
Expand Down Expand Up @@ -94,6 +98,7 @@ def forward(self,
alphas,
token_num,
mask=mask)

acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)

if target_length is None and self.tail_threshold > 0.0:
Expand Down Expand Up @@ -217,6 +222,31 @@ def forward(self, token_length, pre_token_length):
return loss


def cif_without_hidden(alphas: torch.Tensor, threshold: float):
# https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/models/predictor/cif.py#L187
batch_size, len_time = alphas.size()

# loop varss
integrate = torch.zeros([batch_size], device=alphas.device)
# intermediate vars along time
list_fires = []

for t in range(len_time):
alpha = alphas[:, t]

integrate += alpha
list_fires.append(integrate)

fire_place = integrate >= threshold
integrate = torch.where(
fire_place, integrate -
torch.ones([batch_size], device=alphas.device) * threshold,
integrate)

fires = torch.stack(list_fires, 1)
return fires


def cif(hidden: torch.Tensor, alphas: torch.Tensor, threshold: float):
batch_size, len_time, hidden_size = hidden.size()

Expand Down
48 changes: 33 additions & 15 deletions wenet/paraformer/convert_paraformer_to_wenet_config_and_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@
from pathlib import Path
import shutil
import urllib.request
import torch
from tqdm import tqdm
from typing import Dict, List, Optional, Tuple

import yaml

from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.init_model import init_model


def _load_paraformer_cmvn(cmvn_file) -> Tuple[List, List]:
with open(cmvn_file, 'r', encoding='utf-8') as f:
Expand Down Expand Up @@ -107,7 +105,8 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
configs['lfr_conf'] = {'lfr_m': 7, 'lfr_n': 6}

configs['input_dim'] = configs['lfr_conf']['lfr_m'] * 80
configs['predictor'] = 'cif_predictor'
# configs['predictor'] = 'cif_predictor'
configs['predictor'] = 'paraformer_predictor'
configs['predictor_conf'] = configs.pop('predictor_conf')
configs['predictor_conf']['cnn_groups'] = 1
configs['predictor_conf']['residual'] = False
Expand Down Expand Up @@ -162,10 +161,26 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str,
return configs


def convert_to_wenet_state_dict(args, configs, wenet_model_path):
args.checkpoint = args.paraformer_model
model, _ = init_model(args, configs)
save_checkpoint(model, wenet_model_path)
def convert_to_wenet_state_dict(args, wenet_model_path):
wenet_state_dict = {}
checkpoint = torch.load(args.paraformer_model, map_location='cpu')
for name in checkpoint.keys():
wenet_name = name

if wenet_name.startswith('predictor.cif_output2'):
wenet_name = wenet_name.replace('predictor.cif_output2.',
'predictor.tp_output.')
elif wenet_name.startswith('predictor.cif'):
wenet_name = wenet_name.replace('predictor.cif',
'predictor.predictor.cif')
elif wenet_name.startswith('predictor.upsample'):
wenet_name = wenet_name.replace('predictor.', 'predictor.tp_')
elif wenet_name.startswith('predictor.blstm'):
wenet_name = wenet_name.replace('predictor.', 'predictor.tp_')

wenet_state_dict[wenet_name] = checkpoint[name].float()

torch.save(wenet_state_dict, wenet_model_path)


def get_args():
Expand All @@ -190,11 +205,15 @@ def get_args():
return args


def _download_fn(output_dir, name, renmae: Optional[str] = None):
def _download_fn(output_dir,
name,
renmae: Optional[str] = None,
version: str = 'master'):
url = "https://www.modelscope.cn/api/v1/"\
"models/damo/"\
"speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"\
"/repo?Revision=v1.0.4&FilePath=" + name
"speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"\
"/repo?Revision={}&FilePath=".format(version) + name
# "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"\
if renmae is None:
output_file = os.path.join(output_dir, name)
else:
Expand Down Expand Up @@ -232,7 +251,7 @@ def may_get_assets_and_refine_args(args):
config_name = 'config.yaml'
args.paraformer_config = os.path.join(assets_dir, config_name)
if not os.path.exists(args.paraformer_config):
_download_fn(assets_dir, config_name)
_download_fn(assets_dir, config_name, version='v1.2.4')
if args.paraformer_cmvn is None:
cmvn_name = 'am.mvn'
args.paraformer_cmvn = os.path.join(assets_dir, cmvn_name)
Expand Down Expand Up @@ -280,11 +299,10 @@ def main():
'tokenizer_conf'
]
wenet_train_yaml = os.path.join(args.output_dir, "train.yaml")
wenet_configs = convert_to_wenet_yaml(configs, wenet_train_yaml,
fields_to_keep)
convert_to_wenet_yaml(configs, wenet_train_yaml, fields_to_keep)

wenet_model_path = os.path.join(args.output_dir, "wenet_paraformer.pt")
convert_to_wenet_state_dict(args, wenet_configs, wenet_model_path)
convert_to_wenet_state_dict(args, wenet_model_path)

print("Please check {} {} {} {} {} in {}".format(json_cmvn_path,
wenet_train_yaml,
Expand Down
Loading
Loading