From b35247b7df5cd578addf74bf0e1e35116ea20ce6 Mon Sep 17 00:00:00 2001
From: xingchensong
Date: Thu, 18 Apr 2024 11:16:51 +0800
Subject: [PATCH 1/2] [transformer] fix warning: ignore(True) has been
 deprecated

---
 wenet/paraformer/layers.py   | 4 ++--
 wenet/transformer/decoder.py | 2 +-
 wenet/transformer/encoder.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/wenet/paraformer/layers.py b/wenet/paraformer/layers.py
index cbd9f150e..b1cd362a0 100644
--- a/wenet/paraformer/layers.py
+++ b/wenet/paraformer/layers.py
@@ -282,7 +282,7 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
             xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,
@@ -471,7 +471,7 @@ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
             x = layer(x)
         return x
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, x: torch.Tensor,
                                     tgt_mask: torch.Tensor,
                                     memory: torch.Tensor,
diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index f674cea28..00df599f2 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -207,7 +207,7 @@ def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
                                                      memory_mask)
         return x
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, x: torch.Tensor,
                                     tgt_mask: torch.Tensor,
                                     memory: torch.Tensor,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 26cb0ef8c..83ea684fe 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -184,7 +184,7 @@ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
             xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         return xs
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward_layers_checkpointed(self, xs: torch.Tensor,
                                     chunk_masks: torch.Tensor,
                                     pos_emb: torch.Tensor,

From b377fb001a0562e0f96bcdfe63d9341fd6920d0d Mon Sep 17 00:00:00 2001
From: xingchensong
Date: Thu, 18 Apr 2024 11:22:39 +0800
Subject: [PATCH 2/2] [transformer] fix warning: ignore(True) has been
 deprecated

---
 wenet/ctl_model/asr_model_ctl.py     | 2 +-
 wenet/k2/model.py                    | 6 +++---
 wenet/paraformer/paraformer.py       | 4 ++--
 wenet/ssl/w2vbert/w2vbert_model.py   | 2 +-
 wenet/ssl/wav2vec2/wav2vec2_model.py | 2 +-
 wenet/transformer/asr_model.py       | 6 +++---
 6 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/wenet/ctl_model/asr_model_ctl.py b/wenet/ctl_model/asr_model_ctl.py
index c5457e590..6e9bc810a 100644
--- a/wenet/ctl_model/asr_model_ctl.py
+++ b/wenet/ctl_model/asr_model_ctl.py
@@ -67,7 +67,7 @@ def __init__(
         self.ctl_weight = ctl_weight
         self.logit_temp = logit_temp
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: dict,
diff --git a/wenet/k2/model.py b/wenet/k2/model.py
index bbc580cdc..d76d89dd7 100644
--- a/wenet/k2/model.py
+++ b/wenet/k2/model.py
@@ -54,7 +54,7 @@ def __init__(
         if self.lfmmi_dir != '':
             self.load_lfmmi_resource()
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _forward_ctc(
             self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
             text: torch.Tensor,
@@ -63,7 +63,7 @@ def _forward_ctc(
                                          text)
         return loss_ctc, ctc_probs
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def load_lfmmi_resource(self):
         try:
             import icefall
@@ -94,7 +94,7 @@ def load_lfmmi_resource(self):
                 assert len(arr) == 2
                 self.word_table[int(arr[1])] = arr[0]
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _calc_lfmmi_loss(self, encoder_out, encoder_mask, text):
         try:
             import k2
diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py
index 64b3587ec..be19f15b4 100644
--- a/wenet/paraformer/paraformer.py
+++ b/wenet/paraformer/paraformer.py
@@ -148,7 +148,7 @@ def __init__(self,
         # labels: 你 好 we@@ net eos
         self.add_eos = add_eos
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: Dict,
@@ -232,7 +232,7 @@ def _calc_att_loss(
                                      ignore_label=self.ignore_id)
         return loss_att, acc_att
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _sampler(self, encoder_out, encoder_out_mask, ys_pad, ys_pad_lens,
                  pre_acoustic_embeds):
         device = encoder_out.device
diff --git a/wenet/ssl/w2vbert/w2vbert_model.py b/wenet/ssl/w2vbert/w2vbert_model.py
index b87459529..27db0abf1 100644
--- a/wenet/ssl/w2vbert/w2vbert_model.py
+++ b/wenet/ssl/w2vbert/w2vbert_model.py
@@ -158,7 +158,7 @@ def _reset_parameter(module: torch.nn.Module):
         _reset_parameter(conv1)
         _reset_parameter(conv2)
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: Dict,
diff --git a/wenet/ssl/wav2vec2/wav2vec2_model.py b/wenet/ssl/wav2vec2/wav2vec2_model.py
index 69d5af022..9cbd0c3b3 100644
--- a/wenet/ssl/wav2vec2/wav2vec2_model.py
+++ b/wenet/ssl/wav2vec2/wav2vec2_model.py
@@ -217,7 +217,7 @@ def _reset_parameter(module: torch.nn.Module):
         _reset_parameter(conv1)
         _reset_parameter(conv2)
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: Dict,
diff --git a/wenet/transformer/asr_model.py b/wenet/transformer/asr_model.py
index f352271ab..4099947b8 100644
--- a/wenet/transformer/asr_model.py
+++ b/wenet/transformer/asr_model.py
@@ -74,7 +74,7 @@ def __init__(
             normalize_length=length_normalized_loss,
         )
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def forward(
         self,
         batch: dict,
@@ -133,7 +133,7 @@ def forward(
             "th_accuracy": acc_att,
         }
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def _forward_ctc(
             self, encoder_out: torch.Tensor, encoder_mask: torch.Tensor,
             text: torch.Tensor,
@@ -231,7 +231,7 @@ def _forward_encoder(
         )  # (B, maxlen, encoder_dim)
         return encoder_out, encoder_mask
 
-    @torch.jit.ignore(drop=True)
+    @torch.jit.unused
     def ctc_logprobs(self,
                      encoder_out: torch.Tensor,
                      blank_penalty: float = 0.0,
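
Note (cover-letter context, not part of the patches): torch.jit.ignore(drop=True)
and torch.jit.unused do the same thing at export time; PyTorch merely deprecated
the former in favor of the latter, so these commits are a pure one-for-one
decorator swap that silences the "ignore(True) has been deprecated" warning.
Below is a minimal sketch of the pattern the decorated methods follow, using a
hypothetical ToyEncoder module (not wenet code) to illustrate why the methods
must stay out of the TorchScript graph:

    import torch
    from torch.utils.checkpoint import checkpoint


    class ToyEncoder(torch.nn.Module):
        # Hypothetical stand-in for wenet's encoder/decoder/model classes.

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 8)
            self.gradient_checkpointing = False

        def forward_layers(self, xs: torch.Tensor) -> torch.Tensor:
            return self.layer(xs)

        # Previously "@torch.jit.ignore(drop=True)". torch.jit.unused keeps
        # the method callable in eager mode but compiles it to a raising stub,
        # so torch.jit.script() never tries to compile the checkpoint() call.
        @torch.jit.unused
        def forward_layers_checkpointed(self, xs: torch.Tensor) -> torch.Tensor:
            return checkpoint(self.forward_layers, xs)

        def forward(self, xs: torch.Tensor) -> torch.Tensor:
            # The checkpointed path is only taken during eager-mode training,
            # so a scripted module never reaches the unused stub.
            if self.gradient_checkpointing and self.training:
                return self.forward_layers_checkpointed(xs)
            return self.forward_layers(xs)


    scripted = torch.jit.script(ToyEncoder())  # no deprecation warning
    print(scripted(torch.randn(2, 8)).shape)   # torch.Size([2, 8])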