diff --git a/wenet/paraformer/convert_paraformer_to_wenet_config.py b/wenet/paraformer/convert_paraformer_to_wenet_config.py index cce25f783..1c9045603 100644 --- a/wenet/paraformer/convert_paraformer_to_wenet_config.py +++ b/wenet/paraformer/convert_paraformer_to_wenet_config.py @@ -99,6 +99,11 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str, fields_to_keep: List): configs['dataset_conf']['shuffle'] = False configs['dataset_conf']['sort'] = False + configs['grad_clip'] = 5 + configs['accum_grad'] = 1 + configs['max_epoch'] = 100 + configs['log_interval'] = 100 + with open(wenet_yaml_path, '+w') as f: f.write(yaml.dump(configs)) f.flush() diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py index 47e148f80..ebca77c32 100644 --- a/wenet/paraformer/paraformer.py +++ b/wenet/paraformer/paraformer.py @@ -183,7 +183,8 @@ def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens, if target_num > 0: input_mask[li].scatter_( dim=0, - index=torch.randperm(ys_pad_lens[li])[:target_num], + index=torch.randperm(ys_pad_lens[li], + device=device)[:target_num], value=0, ) input_mask = torch.where(input_mask > 0, 1, 0) diff --git a/wenet/utils/init_ali_paraformer.py b/wenet/utils/init_ali_paraformer.py index 13af51757..c0db11357 100644 --- a/wenet/utils/init_ali_paraformer.py +++ b/wenet/utils/init_ali_paraformer.py @@ -17,9 +17,9 @@ def init_model(configs, checkpoint_path=None): encoder = SanmEncoder(global_cmvn=global_cmvn, input_size=configs['lfr_conf']['lfr_m'] * input_dim, **configs['encoder_conf']) - decoder = decoder = SanmDecoder(vocab_size=vocab_size, - encoder_output_size=encoder.output_size(), - **configs['decoder_conf']) + decoder = SanmDecoder(vocab_size=vocab_size, + encoder_output_size=encoder.output_size(), + **configs['decoder_conf']) predictor = Cif(**configs['cif_predictor_conf']) model = Paraformer( vocab_size=vocab_size,