Skip to content

Commit

Permalink
fix device
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Nov 13, 2023
1 parent 1a9b73c commit d60fe79
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
5 changes: 5 additions & 0 deletions wenet/paraformer/convert_paraformer_to_wenet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ def convert_to_wenet_yaml(configs, wenet_yaml_path: str, fields_to_keep: List):
configs['dataset_conf']['shuffle'] = False
configs['dataset_conf']['sort'] = False

configs['grad_clip'] = 5
configs['accum_grad'] = 1
configs['max_epoch'] = 100
configs['log_interval'] = 100

with open(wenet_yaml_path, '+w') as f:
f.write(yaml.dump(configs))
f.flush()
Expand Down
3 changes: 2 additions & 1 deletion wenet/paraformer/paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
if target_num > 0:
input_mask[li].scatter_(
dim=0,
index=torch.randperm(ys_pad_lens[li])[:target_num],
index=torch.randperm(ys_pad_lens[li],
device=device)[:target_num],
value=0,
)
input_mask = torch.where(input_mask > 0, 1, 0)
Expand Down
6 changes: 3 additions & 3 deletions wenet/utils/init_ali_paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def init_model(configs, checkpoint_path=None):
encoder = SanmEncoder(global_cmvn=global_cmvn,
input_size=configs['lfr_conf']['lfr_m'] * input_dim,
**configs['encoder_conf'])
decoder = decoder = SanmDecoder(vocab_size=vocab_size,
encoder_output_size=encoder.output_size(),
**configs['decoder_conf'])
decoder = SanmDecoder(vocab_size=vocab_size,
encoder_output_size=encoder.output_size(),
**configs['decoder_conf'])
predictor = Cif(**configs['cif_predictor_conf'])
model = Paraformer(
vocab_size=vocab_size,
Expand Down

0 comments on commit d60fe79

Please sign in to comment.