Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Mddct committed Nov 13, 2023
1 parent 5b6d758 commit f8fea10
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions wenet/paraformer/paraformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def __init__(self,
del self.decoder.embed
# NOTE(Mddct): add eos in tail of labels for predictor
# eg:
#. gt: 你 好 we@@ net
#. labels: 你 好 we@@ net eos
# gt: 你 好 we@@ net
# labels: 你 好 we@@ net eos
self.add_eos = add_eos

@torch.jit.ignore(drop=True)
Expand Down Expand Up @@ -165,7 +165,6 @@ def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
device = encoder_out.device
B, _ = ys_pad.size()

# update from: https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/models/e2e_asr_paraformer.py#L526
tgt_mask = make_non_pad_mask(ys_pad_lens)
ys_pad_embed = self.embed(ys_pad) # [B, T, L]
with torch.no_grad():
Expand Down Expand Up @@ -193,8 +192,8 @@ def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
input_mask = input_mask * tgt_mask
input_mask_expand = input_mask.unsqueeze(2) # [B, T, 1]

sematic_embeds = torch.where(input_mask_expand == 1, ys_pad_embed,
pre_acoustic_embeds)
sematic_embeds = torch.where(input_mask_expand == 1,
pre_acoustic_embeds, ys_pad_embed)
# zero out the paddings
return sematic_embeds * tgt_mask.unsqueeze(2)

Expand Down

0 comments on commit f8fea10

Please sign in to comment.