diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 95b697263..4892ddb4a 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -2,6 +2,7 @@ ## Conformer Result Bidecoder (large) +* Encoder FLOPs(30s): 96,238,430,720, params: 85,709,704 * Feature info: using fbank feature, cmvn, dither, online speed perturb * Training info: train_conformer_bidecoder_large.yaml, kernel size 31, lr 0.002, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 * Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 @@ -18,8 +19,31 @@ | LM-tglarge + attention rescoring | 2.68 | 6.10 | | LM-fglarge + attention rescoring | 2.65 | 5.98 | +## SqueezeFormer Result (U2++, FFN:2048) + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d, batch_norm, syncbn + * encoder_dim 512, output_size 512, head 8, ffn_dim 512*4=2048 + * Encoder FLOPs(30s): 82,283,704,832, params: 85,984,648 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb, spec_aug +* Training info: + * train_squeezeformer_bidecoder_large.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 + * adamw, lr 8e-4, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 2.62 | 6.80 | 2.92 | 6.77 | +| ctc prefix beam search | 2.60 | 6.79 | 2.90 | 6.79 | +| attention decoder | 3.06 | 6.90 | 3.38 | 6.82 | +| attention rescoring | 2.33 | 6.29 | 2.57 | 6.22 | + ## Conformer Result +* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 * Feature info: using fbank feature, cmvn, dither, online speed perturb * Training info: train_conformer.yaml, kernel size 31, lr 0.004, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 * Decoding info: ctc_weight 0.5, average_num 30 @@ -35,6 +59,82 @@ | attention rescoring (beam 50) | 3.12 | 8.55 | | LM-fglarge + attention rescoring | 3.09 | 7.40 | +## Conformer Result (12 layers, FFN:2048) +* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: train_squeezeformer.yaml, kernel size 31, +* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.49 | 9.59 | 3.66 | 9.59 | +| ctc prefix beam search | 3.49 | 9.61 | 3.66 | 9.55 | +| attention decoder | 3.52 | 9.04 | 3.85 | 8.97 | +| attention rescoring | 3.10 | 8.91 | 3.29 | 8.81 | + +## SqueezeFormer Result (SM12, FFN:1024) +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*4=1024 + * Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31, + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr=1e-3, noamhold, warmup=0.2, hold=0.3, lr_decay=1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other 
| test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | +| ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | +| attention decoder | 3.59 | 8.74 | 3.75 | 8.70 | +| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | + +## SqueezeFormer Result (SM12, FFN:2048) +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.34 | 9.01 | 3.47 | 8.85 | +| ctc prefix beam search | 3.33 | 9.02 | 3.46 | 8.81 | +| attention decoder | 3.64 | 8.62 | 3.91 | 8.33 | +| attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 | + +## SqueezeFormer Result (SM12, FFN:1312) +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d, w/o syncbn + * encoder_dim 328, output_size 256, head 4, ffn_dim 328*4=1312 + * encoder FLOPs(30s): 34,103,960,008, params: 35,678,352 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31, + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 + * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.20 | 8.46 | 3.30 | 8.58 | +| ctc prefix beam search | 3.18 | 8.44 | 3.30 | 8.55 | +| attention decoder | 3.38 | 8.31 | 3.89 | 8.32 | +| attention rescoring | 2.81 | 7.86 | 2.96 | 7.91 | + ## Conformer U2++ Result * Feature info: using fbank feature, cmvn, no speed perturb, dither @@ -43,17 +143,48 @@ * Git hash: 65270043fc8c2476d1ab95e7c39f730017a670e0 test clean + | decoding mode | full | 16 | |--------------------------------|------|------| | ctc prefix beam search | 3.76 | 4.54 | | attention rescoring | 3.32 | 3.80 | test other + | decoding mode | full | 16 | |--------------------------------|-------|-------| | ctc prefix beam search | 9.50 | 11.52 | | attention rescoring | 8.67 | 10.38 | +## SqueezeFormer Result (U2++, FFN:2048) + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d, layer_norm, do_rel_shift false + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr 1e-3, NoamHold, warmup 0.1, hold 0.4, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +test clean + +| decoding mode | full | 16 | +|--------------------------------|------|------| +| ctc prefix beam search | 3.81 | 4.59 | +| attention rescoring | 3.36 | 3.93 | + +test other + +| decoding 
mode | full | 16 | +|--------------------------------|-------|-------| +| ctc prefix beam search | 9.12 | 11.17 | +| attention rescoring | 8.43 | 10.21 | + ## Conformer U2 Result * Feature info: using fbank feature, cmvn, speed perturb, dither @@ -65,6 +196,7 @@ test other * LM-fglarge: [4-gram.arpa.gz](http://www.openslr.org/resources/11/4-gram.arpa.gz) test clean + | decoding mode | full | 16 | |----------------------------------|------|------| | ctc prefix beam search | 4.26 | 5.00 | @@ -76,6 +208,7 @@ test clean | LM-fglarge + attention rescoring | 3.38 | 3.74 | test other + | decoding mode | full | 16 | |----------------------------------|-------|-------| | ctc prefix beam search | 10.87 | 12.87 | diff --git a/examples/librispeech/s0/conf/train_squeezeformer.yaml b/examples/librispeech/s0/conf/train_squeezeformer.yaml new file mode 100644 index 000000000..15dd2d33b --- /dev/null +++ b/examples/librispeech/s0/conf/train_squeezeformer.yaml @@ -0,0 +1,88 @@ +# network architecture +# encoder related +encoder: squeezeformer +encoder_conf: + encoder_dim: 256 + output_size: 256 # dimension of attention + attention_heads: 4 + num_blocks: 12 # the number of encoder blocks + reduce_idx: 5 + recover_idx: 11 + pos_enc_layer_type: 'rel_pos' + time_reduction_layer_type: 'conv1d' + feed_forward_expansion_factor: 4 + input_dropout_rate: 0.1 + feed_forward_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + cnn_module_kernel: 31 + cnn_norm_type: layer_norm + adaptive_scale: true + normalize_before: false + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# dataset related +dataset_conf: + filter_conf: + max_length: 2000 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + min_output_input_ratio: 0.0005 + max_output_input_ratio: 0.1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 0.1 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 12 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 120 +log_interval: 100 + +optim: adamw +optim_conf: + lr: 1.e-3 + weight_decay: 4.e-5 + +scheduler: NoamHoldAnnealing +scheduler_conf: + warmup_ratio: 0.2 + hold_ratio: 0.3 + max_steps: 87960 + decay_rate: 1.0 + min_lr: 1.e-5 diff --git a/examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml b/examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml new file mode 100644 index 000000000..3bac09e6e --- /dev/null +++ b/examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml @@ -0,0 +1,96 @@ +# network architecture +# encoder related +encoder: squeezeformer +encoder_conf: + encoder_dim: 512 + output_size: 512 # dimension of attention + attention_heads: 8 + num_blocks: 12 # the number of encoder blocks + reduce_idx: 5 + recover_idx: 11 + feed_forward_expansion_factor: 4 + input_dropout_rate: 0.1 + feed_forward_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + cnn_module_kernel: 31 + cnn_norm_type: batch_norm + adaptive_scale: 
true + normalize_before: false + +# decoder related +decoder: bitransformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 3 + r_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + reverse_weight: 0.3 + +# dataset related +dataset_conf: + syncbn: true + filter_conf: + max_length: 2000 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + min_output_input_ratio: 0.0005 + max_output_input_ratio: 0.1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 3 + num_f_mask: 2 + max_t: 100 + max_f: 27 + max_w: 80 +# warp_for_time: true + spec_sub: true + spec_sub_conf: + num_t_sub: 3 + max_t: 30 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 12 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 500 +log_interval: 100 + +optim: adamw +optim_conf: + lr: 1.e-3 + weight_decay: 4.e-5 + +scheduler: NoamHoldAnnealing +scheduler_conf: + warmup_ratio: 0.2 + hold_ratio: 0.3 + max_steps: 87960 + decay_rate: 1.0 + min_lr: 1.e-5 + diff --git a/wenet/bin/train.py b/wenet/bin/train.py index f03404676..53f14e2af 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -31,7 +31,7 @@ load_trained_modules) from wenet.utils.executor import Executor from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols -from wenet.utils.scheduler import WarmupLR +from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing from wenet.utils.config import override_config from wenet.utils.init_model import init_model @@ -195,7 +195,7 @@ def main(): model = init_model(configs) print(model) num_params = sum(p.numel() for p in model.parameters()) - print('the number of model params: {}'.format(num_params)) + print('the number of model params: {:,d}'.format(num_params)) # !!!IMPORTANT!!! 
# Try to export the model by script, if fails, we should refine @@ -243,8 +243,19 @@ def main(): device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) - optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) - scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) + if configs['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) + elif configs['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['optim_conf']) + else: + raise ValueError("unknown optimizer: " + configs['optim']) + if configs['scheduler'] == 'warmuplr': + scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) + elif configs['scheduler'] == 'NoamHoldAnnealing': + scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf']) + else: + raise ValueError("unknown scheduler: " + configs['scheduler']) + final_epoch = None configs['rank'] = args.rank configs['is_distributed'] = distributed diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py new file mode 100644 index 000000000..141113f1b --- /dev/null +++ b/wenet/squeezeformer/attention.py @@ -0,0 +1,223 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2022 Ximalaya Inc. (Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-Head Attention layer definition.""" + +import math +import torch +import torch.nn as nn +from wenet.transformer.attention import MultiHeadedAttention +from typing import Tuple + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. 
+ """ + + def __init__(self, n_head, n_feat, dropout_rate, + do_rel_shift=False, adaptive_scale=False, init_weights=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.do_rel_shift = do_rel_shift + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + self.adaptive_scale = adaptive_scale + if self.adaptive_scale: + self.ada_scale = nn.Parameter( + torch.ones([1, 1, n_feat]), requires_grad=True) + self.ada_bias = nn.Parameter( + torch.zeros([1, 1, n_feat]), requires_grad=True) + if init_weights: + self.init_weights() + + def init_weights(self): + input_max = (self.h * self.d_k) ** -0.5 + torch.nn.init.uniform_(self.linear_q.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_q.bias, -input_max, input_max) + torch.nn.init.uniform_(self.linear_k.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_k.bias, -input_max, input_max) + torch.nn.init.uniform_(self.linear_v.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_v.bias, -input_max, input_max) + torch.nn.init.uniform_(self.linear_pos.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_out.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_out.bias, -input_max, input_max) + + def rel_shift(self, x, zero_triu: bool = False): + """Compute relative positinal encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, size). + zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + torch.Tensor: Output tensor. + """ + + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward_attention( + self, value: torch.Tensor, scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. 
pytorch training + if mask.size(2) > 0: # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + # (batch, head, time1, time2) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + """ + if self.adaptive_scale: + query = self.ada_scale * query + self.ada_bias + key = self.ada_scale * key + self.ada_bias + value = self.ada_scale * value + self.ada_bias + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + # NOTE(xcsong): + # when export onnx model, for 1st chunk, we feed + # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) + # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). + # In all modes, `if cache.size(0) > 0` will alwayse be `True` + # and we will always do splitting and + # concatnation(this will simplify onnx export). Note that + # it's OK to concat & split zero-shaped tensors(see code below). + # when export jit model, for 1st chunk, we always feed + # cache(0, 0, 0, 0) since jit supports dynamic if-branch. + # >>> a = torch.ones((1, 2, 0, 4)) + # >>> b = torch.ones((1, 2, 3, 4)) + # >>> c = torch.cat((a, b), dim=2) + # >>> torch.equal(b, c) # True + # >>> d = torch.split(a, 2, dim=-1) + # >>> torch.equal(d[0], d[1]) # True + if cache.size(0) > 0: + key_cache, value_cache = torch.split( + cache, cache.size(-1) // 2, dim=-1) + k = torch.cat([key_cache, k], dim=2) + v = torch.cat([value_cache, v], dim=2) + # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's + # non-trivial to calculate `next_cache_start` here. 
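+        # NOTE: at this point k and v each have shape
+        # (batch, head, cache_t + time2, d_k), so the concatenation below
+        # produces the documented cache layout
+        # (batch, head, cache_t + time2, d_k * 2); the split at the top of
+        # the next call recovers the key/value halves unchanged.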
+ new_cache = torch.cat((k, v), dim=-1) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + # Remove rel_shift since it is useless in speech recognition, + # and it requires special attention for streaming. + if self.do_rel_shift: + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask), new_cache diff --git a/wenet/squeezeformer/conv2d.py b/wenet/squeezeformer/conv2d.py new file mode 100644 index 000000000..c23026339 --- /dev/null +++ b/wenet/squeezeformer/conv2d.py @@ -0,0 +1,66 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conv2d Module with Valid Padding""" + +import torch.nn.functional as F +from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional + + +class Conv2dValid(_ConvNd): + """ + Conv2d operator for VALID mode padding. 
+ """ + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + valid_trigx: bool = False, + valid_trigy: bool = False + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super(Conv2dValid, self).__init__( + in_channels, out_channels, kernel_size_, + stride_, padding_, dilation_, False, _pair(0), + groups, bias, padding_mode, **factory_kwargs) + self.valid_trigx = valid_trigx + self.valid_trigy = valid_trigy + + def _conv_forward( + self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + validx, validy = 0, 0 + if self.valid_trigx: + validx = (input.size(-2) * (self.stride[-2] - 1) - 1 + + self.kernel_size[-2]) // 2 + if self.valid_trigy: + validy = (input.size(-1) * (self.stride[-1] - 1) - 1 + + self.kernel_size[-1]) // 2 + return F.conv2d(input, weight, bias, self.stride, + (validx, validy), self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py new file mode 100644 index 000000000..e3a4e7a75 --- /dev/null +++ b/wenet/squeezeformer/convolution.py @@ -0,0 +1,175 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn +from typeguard import check_argument_types + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True, + adaptive_scale: bool = False, + init_weights: bool = False + ): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. 
+ causal (int): Whether use causal convolution or not + """ + assert check_argument_types() + super().__init__() + self.bias = bias + self.channels = channels + self.kernel_size = kernel_size + self.adaptive_scale = adaptive_scale + if self.adaptive_scale: + self.ada_scale = torch.nn.Parameter( + torch.ones([1, 1, channels]), requires_grad=True) + self.ada_bias = torch.nn.Parameter( + torch.zeros([1, 1, channels]), requires_grad=True) + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # self.lorder is used to distinguish if it's a causal convolution, + # if self.lorder > 0: it's a causal convolution, the input will be + # padded with self.lorder frames on the left in forward. + # else: it's a symmetrical convolution + if causal: + padding = 0 + self.lorder = kernel_size - 1 + else: + # kernel_size should be an odd number for none causal convolution + assert (kernel_size - 1) % 2 == 0 + padding = (kernel_size - 1) // 2 + self.lorder = 0 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + + assert norm in ['batch_norm', 'layer_norm'] + if norm == "batch_norm": + self.use_layer_norm = False + self.norm = nn.BatchNorm1d(channels) + else: + self.use_layer_norm = True + self.norm = nn.LayerNorm(channels) + + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + if init_weights: + self.init_weights() + + def init_weights(self): + pw_max = self.channels ** -0.5 + dw_max = self.kernel_size ** -0.5 + torch.nn.init.uniform_(self.pointwise_conv1.weight.data, -pw_max, pw_max) + if self.bias: + torch.nn.init.uniform_(self.pointwise_conv1.bias.data, -pw_max, pw_max) + torch.nn.init.uniform_(self.depthwise_conv.weight.data, -dw_max, dw_max) + if self.bias: + torch.nn.init.uniform_(self.depthwise_conv.bias.data, -dw_max, dw_max) + torch.nn.init.uniform_(self.pointwise_conv2.weight.data, -pw_max, pw_max) + if self.bias: + torch.nn.init.uniform_(self.pointwise_conv2.bias.data, -pw_max, pw_max) + + def forward( + self, + x: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + cache: torch.Tensor = torch.zeros((0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute convolution module. + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), + (0, 0, 0) means fake mask. + cache (torch.Tensor): left context cache, it is only + used in causal convolution (#batch, channels, cache_t), + (0, 0, 0) meas fake cache. + Returns: + torch.Tensor: Output tensor (#batch, time, channels). + """ + if self.adaptive_scale: + x = self.ada_scale * x + self.ada_bias + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) # (#batch, channels, time) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + if self.lorder > 0: + if cache.size(2) == 0: # cache_t == 0 + x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) + else: + assert cache.size(0) == x.size(0) # equal batch + assert cache.size(1) == x.size(1) # equal channel + x = torch.cat((cache, x), dim=2) + assert (x.size(2) > self.lorder) + new_cache = x[:, :, -self.lorder:] + else: + # It's better we just return None if no cache is required, + # However, for JIT export, here we just fake one tensor instead of + # None. 
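+            # A zero-sized (0, 0, 0) placeholder keeps the return type a
+            # Tensor for JIT scripting; its empty time dimension marks it
+            # as "no cache" for the caller.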
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.activation(self.norm(x)) + if self.use_layer_norm: + x = x.transpose(1, 2) + x = self.pointwise_conv2(x) + # mask batch padding + if mask_pad.size(2) > 0: # time > 0 + x.masked_fill_(~mask_pad, 0.0) + + return x.transpose(1, 2), new_cache diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py new file mode 100644 index 000000000..6c7b3cd91 --- /dev/null +++ b/wenet/squeezeformer/encoder.py @@ -0,0 +1,388 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) +# Squeezeformer(https://github.com/upskyy/Squeezeformer) +# NeMo(https://github.com/NVIDIA/NeMo) + +import torch +import torch.nn as nn +from typing import Tuple +from wenet.squeezeformer.subsampling \ + import DepthwiseConv2dSubsampling4, TimeReductionLayer1D, TimeReductionLayer2D +from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer +from wenet.transformer.embedding import RelPositionalEncoding +from wenet.transformer.attention import MultiHeadedAttention +from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention +from wenet.squeezeformer.positionwise_feed_forward \ + import PositionwiseFeedForward +from wenet.squeezeformer.convolution import ConvolutionModule +from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask +from wenet.utils.common import get_activation + + +class SqueezeformerEncoder(nn.Module): + def __init__( + self, + input_size: int = 80, + encoder_dim: int = 256, + output_size: int = 256, + attention_heads: int = 4, + num_blocks: int = 12, + reduce_idx: int = 5, + recover_idx: int = 11, + feed_forward_expansion_factor: int = 4, + dw_stride: bool = False, + input_dropout_rate: float = 0.1, + pos_enc_layer_type: str = "rel_pos", + time_reduction_layer_type: str = "conv1d", + do_rel_shift: bool = True, + feed_forward_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + cnn_module_kernel: int = 31, + cnn_norm_type: str = "batch_norm", + dropout: float = 0.1, + causal: bool = False, + adaptive_scale: bool = True, + activation_type: str = "swish", + init_weights: bool = True, + global_cmvn: torch.nn.Module = None, + normalize_before: bool = False, + use_dynamic_chunk: bool = False, + concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_left_chunk: bool = False + ): + """Construct SqueezeformerEncoder + + Args: + input_size to use_dynamic_chunk, see in Transformer BaseEncoder. + encoder_dim (int): The hidden dimension of encoder layer. + output_size (int): The output dimension of final projection layer. + attention_heads (int): Num of attention head in attention module. 
+ num_blocks (int): Num of encoder layers. + reduce_idx (int): reduce layer index, from 40ms to 80ms per frame. + recover_idx (int): recover layer index, from 80ms to 40ms per frame. + feed_forward_expansion_factor (int): Enlarge coefficient of FFN. + dw_stride (bool): Whether do depthwise convolution + on subsampling module. + input_dropout_rate (float): Dropout rate of input projection layer. + pos_enc_layer_type (str): Self attention type. + time_reduction_layer_type (str): Conv1d or Conv2d reduction layer. + do_rel_shift (bool): Whether to do relative shift + operation on rel-attention module. + cnn_module_kernel (int): Kernel size of CNN module. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + adaptive_scale (bool): Whether to use adaptive scale. + init_weights (bool): Whether to initialize weights. + causal (bool): whether to use causal convolution or not. + """ + super(SqueezeformerEncoder, self).__init__() + self.global_cmvn = global_cmvn + self.reduce_idx = reduce_idx + self.recover_idx = recover_idx + self._output_size = output_size + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + activation = get_activation(activation_type) + + # self-attention module definition + if pos_enc_layer_type != "rel_pos": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + else: + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + encoder_dim, + attention_dropout_rate, + do_rel_shift, + adaptive_scale, + init_weights + ) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + encoder_dim, + encoder_dim * feed_forward_expansion_factor, + feed_forward_dropout_rate, + activation, + adaptive_scale, + init_weights + ) + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = ( + encoder_dim, cnn_module_kernel, activation, + cnn_norm_type, causal, True, adaptive_scale, init_weights) + + self.embed = DepthwiseConv2dSubsampling4( + 1, encoder_dim, + RelPositionalEncoding(encoder_dim, dropout_rate=0.1), + dw_stride + ) + self.input_proj = nn.Sequential( + nn.Linear( + encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim), + nn.Dropout(p=input_dropout_rate), + ) + if init_weights: + linear_max = (encoder_dim * input_size / 4) ** -0.5 + torch.nn.init.uniform_( + self.input_proj.state_dict()['0.weight'], + -linear_max, linear_max) + torch.nn.init.uniform_( + self.input_proj.state_dict()['0.bias'], + -linear_max, linear_max) + self.preln = nn.LayerNorm(encoder_dim) + self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after) for _ in range(num_blocks) + ]) + if time_reduction_layer_type == 'conv1d': + time_reduction_layer = TimeReductionLayer1D + time_reduction_layer_args = { + 'ichannel': encoder_dim, + 'ochannel': encoder_dim, + } + else: + time_reduction_layer = TimeReductionLayer2D + time_reduction_layer_args = {'encoder_dim': encoder_dim} + + 
self.time_reduction_layer = \ + time_reduction_layer(**time_reduction_layer_args) + self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) + self.final_proj = None + if output_size != encoder_dim: + self.final_proj = nn.Linear(encoder_dim, output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + xs_lens = chunk_masks.squeeze(1).sum(1) + xs = self.input_proj(xs) + xs = self.preln(xs) + recover_tensor = torch.tensor(0.) + recover_chunk_masks = torch.tensor(0.) + recover_pos_emb = torch.tensor(0.) + recover_mask_pad = torch.tensor(0.) + for idx, layer in enumerate(self.encoders): + if idx == self.reduce_idx: + recover_tensor = xs + recover_chunk_masks = chunk_masks + recover_pos_emb = pos_emb + recover_mask_pad = mask_pad + xs, xs_lens, chunk_masks, mask_pad = \ + self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) + pos_emb = pos_emb[:, :xs.size(1), :] + + if idx == self.recover_idx: + # recover output length for ctc decode + xs = torch.repeat_interleave(xs, repeats=2, dim=1) + xs = self.time_recover_layer(xs) + xs = recover_tensor + xs[:, :recover_tensor.size(1), :] + chunk_masks = recover_chunk_masks + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + if self.reduce_idx <= idx < self.recover_idx: + residual = xs + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + xs += residual + else: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + + if self.final_proj is not None: + xs = self.final_proj(xs) + return xs, masks + + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. + cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. 
+ torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding( + offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, att_mask, pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache + ) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + if self.normalize_before: + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not prefered. 
+ Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, att_cache, cnn_cache) = self.forward_chunk( + chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + return ys, masks diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py new file mode 100644 index 000000000..3c6bdd44a --- /dev/null +++ b/wenet/squeezeformer/encoder_layer.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""SqueezeformerEncoderLayer definition.""" + +import torch +import torch.nn as nn +from typing import Optional, Tuple + + +class SqueezeformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward1 (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + feed_forward2 (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. 
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward1: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + feed_forward2: Optional[nn.Module] = None, + normalize_before: bool = False, + dropout_rate: float = 0.1, + concat_after: bool = False, + ): + super(SqueezeformerEncoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.layer_norm1 = nn.LayerNorm(size) + self.ffn1 = feed_forward1 + self.layer_norm2 = nn.LayerNorm(size) + self.conv_module = conv_module + self.layer_norm3 = nn.LayerNorm(size) + self.ffn2 = feed_forward2 + self.layer_norm4 = nn.LayerNorm(size) + self.normalize_before = normalize_before + self.dropout = nn.Dropout(dropout_rate) + self.concat_after = concat_after + if concat_after: + self.concat_linear = nn.Linear(size + size, size) + else: + self.concat_linear = nn.Identity() + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # self attention module + residual = x + if self.normalize_before: + x = self.layer_norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.layer_norm1(x) + + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm2(x) + x = self.ffn1(x) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm2(x) + + # conv module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + residual = x + if self.normalize_before: + x = self.layer_norm3(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm3(x) + + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm4(x) + x = self.ffn2(x) + # we do not use dropout here since it is inside feed forward function + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm4(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/wenet/squeezeformer/positionwise_feed_forward.py b/wenet/squeezeformer/positionwise_feed_forward.py new file mode 100644 index 000000000..e02d2f0dc --- /dev/null +++ b/wenet/squeezeformer/positionwise_feed_forward.py @@ -0,0 +1,78 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Ximalaya Inc (Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. 
+ + FeedForward are appied on each position of the sequence. + The output dim is same with the input dim. + + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + activation (torch.nn.Module): Activation function + """ + + def __init__(self, + idim: int, + hidden_units: int, + dropout_rate: float, + activation: torch.nn.Module = torch.nn.ReLU(), + adaptive_scale: bool = False, + init_weights: bool = False + ): + """Construct a PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.idim = idim + self.hidden_units = hidden_units + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.activation = activation + self.dropout = torch.nn.Dropout(dropout_rate) + self.w_2 = torch.nn.Linear(hidden_units, idim) + self.adaptive_scale = adaptive_scale + if self.adaptive_scale: + self.ada_scale = torch.nn.Parameter( + torch.ones([1, 1, idim]), requires_grad=True) + self.ada_bias = torch.nn.Parameter( + torch.zeros([1, 1, idim]), requires_grad=True) + if init_weights: + self.init_weights() + + def init_weights(self): + ffn1_max = self.idim ** -0.5 + ffn2_max = self.hidden_units ** -0.5 + torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max) + torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max) + torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max) + torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + xs: input tensor (B, L, D) + Returns: + output tensor, (B, L, D) + """ + if self.adaptive_scale: + xs = self.ada_scale * xs + self.ada_bias + return self.w_2(self.dropout(self.activation(self.w_1(xs)))) diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py new file mode 100644 index 000000000..d9cd89973 --- /dev/null +++ b/wenet/squeezeformer/subsampling.py @@ -0,0 +1,208 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) +# Squeezeformer(https://github.com/upskyy/Squeezeformer) +# NeMo(https://github.com/NVIDIA/NeMo) + +"""DepthwiseConv2dSubsampling4 and TimeReductionLayer definition.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from wenet.transformer.subsampling import BaseSubsampling +from typing import Tuple +from wenet.squeezeformer.conv2d import Conv2dValid + + +class DepthwiseConv2dSubsampling4(BaseSubsampling): + """Depthwise Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + pos_enc_class (nn.Module): position encoding class. 
+ + """ + + def __init__( + self, idim: int, odim: int, + pos_enc_class: torch.nn.Module, + dw_stride: bool = False + ): + super(DepthwiseConv2dSubsampling4, self).__init__() + self.idim = idim + self.odim = odim + self.pw_conv = nn.Conv2d( + in_channels=idim, out_channels=odim, kernel_size=3, stride=2) + self.act1 = nn.ReLU() + self.dw_conv = nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2, + groups=odim if dw_stride else 1 + ) + self.act2 = nn.ReLU() + self.pos_enc = pos_enc_class + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.pw_conv(x) + x = self.act1(x) + x = self.dw_conv(x) + x = self.act2(x) + b, c, t, f = x.size() + x = x.permute(0, 2, 1, 3) + x = x.contiguous().view(b, t, c * f) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + + +class TimeReductionLayer1D(nn.Module): + """ + Modified NeMo, + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of + MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. + kernel_size (int): Conv kernel size for + depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. + """ + + def __init__(self, channel: int, out_dim: int, + kernel_size: int = 5, stride: int = 2): + super(TimeReductionLayer1D, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + self.padding = max(0, self.kernel_size - self.stride) + + self.dw_conv = nn.Conv1d( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + groups=channel, + ) + + self.pw_conv = nn.Conv1d( + in_channels=channel, out_channels=out_dim, + kernel_size=1, stride=1, padding=0, groups=1, + ) + + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.channel ** -0.5 + torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) + + def forward(self, xs, xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ): + xs = xs.transpose(1, 2) # [B, C, T] + xs = xs.masked_fill(mask_pad.eq(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose(1, 2) # [B, T, C] + + B, T, D = xs.size() + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.size(-1) + # For JIT exporting, we remove F.pad operator. 
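+        # After the stride-2 depthwise/pointwise convs, the time length T of
+        # xs may not match the strided mask length L, so we align xs to L by
+        # truncating it or appending zeros instead of calling F.pad.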
+ if L - T < 0: + xs = xs[:, :L - T, :].contiguous() + else: + dummy_pad = torch.zeros(B, L - T, D, device=xs.device) + xs = torch.cat([xs, dummy_pad], dim=1) + + xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc') + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayer2D(nn.Module): + def __init__( + self, ichannel: int = 1, ochannel: int = 1, + kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256): + super(TimeReductionLayer2D, self).__init__() + self.ichannel = ichannel + self.kernel_size = kernel_size + self.dw_conv = Conv2dValid( + in_channels=ichannel, + out_channels=ochannel, + kernel_size=(kernel_size, 1), + stride=stride, + valid_trigy=True + ) + self.pw_conv = Conv2dValid( + in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=1, + stride=1, + valid_trigx=False, + valid_trigy=False, + ) + + self.in_channels = ichannel + self.kernel_size = kernel_size + self.stride = stride + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.ichannel ** -0.5 + torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) + + def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + xs = xs.masked_fill(mask_pad.transpose(1, 2).eq(0), 0.0) + xs = xs.unsqueeze(2) + padding1 = self.kernel_size - self.stride + xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), + mode='constant', value=0.) + xs = self.dw_conv(xs.transpose(1, 2)) + xs = xs.permute(0, 3, 1, 2).contiguous() + xs = self.pw_conv(xs).permute(0, 2, 3, 1).squeeze(1).contiguous() + tmp_length = xs.size(1) + xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc') + padding2 = max(0, (xs_lens.max() - tmp_length).data.item()) + batch_size, hidden = xs.size(0), xs.size(-1) + dummy_pad = torch.zeros(batch_size, padding2, hidden, device=xs.device) + xs = torch.cat([xs, dummy_pad], dim=1) + mask = mask[:, ::2, ::2] + mask_pad = mask_pad[:, :, ::2] + return xs, xs_lens, mask, mask_pad diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 1e90c7b18..b739c82be 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -22,6 +22,7 @@ from wenet.transformer.ctc import CTC from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder +from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.utils.cmvn import load_cmvn @@ -44,6 +45,10 @@ def init_model(configs): encoder = ConformerEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'squeezeformer': + encoder = SqueezeformerEncoder(input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) else: encoder = TransformerEncoder(input_dim, global_cmvn=global_cmvn, diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index e4705cdf4..383096cbb 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -1,4 +1,5 @@ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Ximalaya Inc (Yuguang Yang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
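Both time-reduction layers above update the frame counts with `torch.div(xs_lens + 1, 2, rounding_mode='trunc')`. A small standalone sketch of that length bookkeeping (the lengths below are illustrative, not taken from the recipe):

    import torch

    # Frame counts after the 1/4 convolutional subsampling (illustrative values).
    xs_lens = torch.tensor([999, 1000, 1001])
    # Halve with truncation, exactly as TimeReductionLayer1D/2D do.
    reduced = torch.div(xs_lens + 1, 2, rounding_mode='trunc')
    print(reduced.tolist())  # [500, 500, 501]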
@@ -12,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # Modified from ESPnet(https://github.com/espnet/espnet) +# NeMo(https://github.com/NVIDIA/NeMo) from typing import Union +import math +import warnings import torch from torch.optim.lr_scheduler import _LRScheduler @@ -39,10 +43,10 @@ class WarmupLR(_LRScheduler): """ def __init__( - self, - optimizer: torch.optim.Optimizer, - warmup_steps: Union[int, float] = 25000, - last_epoch: int = -1, + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, ): assert check_argument_types() self.warmup_steps = warmup_steps @@ -71,3 +75,596 @@ def get_lr(self): def set_step(self, step: int): self.last_epoch = step + + +class WarmupPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, + max_steps=None, min_lr=0.0, last_epoch=-1): + assert not (warmup_steps is not None and warmup_ratio is not None),\ + "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning + ) + + step = self.last_epoch + + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class SquareRootConstantPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, optimizer, *, constant_steps=None, constant_ratio=None, + max_steps=None, min_lr=0.0, last_epoch=-1 + ): + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use particular number of step or ratio" + assert constant_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
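+        # Resolve constant_steps either directly or as constant_ratio * max_steps.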
+ self.max_steps = max_steps + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.constant_lr = 1 / (constant_steps ** 0.5) + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning + ) + + step = self.last_epoch + + if step <= self.constant_steps: + return [self.constant_lr for _ in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class WarmupHoldPolicy(WarmupPolicy): + """Variant of WarmupPolicy which maintains high + learning rate for a defined number of steps. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + hold_steps=None, + hold_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (hold_steps is not None and hold_ratio is not None), \ + "Either use particular number of step or ratio" + assert hold_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + self.min_lr = min_lr + self._last_warmup_lr = 0.0 + + # Necessary to duplicate as class attributes are hidden in inner class + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if hold_steps is not None: + self.hold_steps = hold_steps + self.warmup_steps + elif hold_ratio is not None: + self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps + else: + self.hold_steps = 0 + + super().__init__( + optimizer, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + ) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler," + " " "please use `get_last_lr()`.", + UserWarning + ) + + step = self.last_epoch + + # Warmup phase + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + # Hold phase + if (step >= self.warmup_steps) and (step < self.hold_steps): + return self.base_lrs + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + +class WarmupAnnealHoldPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + min_lr: Minimum lr to hold the learning rate after decay at. + constant_steps: Number of steps to keep lr constant at. 
+ constant_ratio: Ratio of steps to keep lr constant. + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (warmup_steps is not None + and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use constant_steps or constant_ratio" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps) + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning + ) + + step = self.last_epoch + + # Warmup steps + if self.warmup_steps > 0 and step <= self.warmup_steps: + return self._get_warmup_lr(step) + + # Constant steps after warmup and decay + if self.constant_steps > 0 and ( + self.warmup_steps + self.decay_steps) < step <= self.max_steps: + return self._get_constant_lr(step) + + # Min lr after max steps of updates + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_constant_lr(self, step): + return [self.min_lr for _ in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +def _squareroot_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps) ** 0.5 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _square_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps) ** 2 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _cosine_annealing(initial_lr, step, max_steps, min_lr): + mult = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + out_lr = (initial_lr - min_lr) * mult + min_lr + return out_lr + + +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, + decay_steps, min_lr): + assert max_lr > min_lr + # Use linear warmup for the initial part. + if warmup_steps > 0 and step <= warmup_steps: + return max_lr * float(step) / float(warmup_steps) + + # For any steps larger than `decay_steps`, use `min_lr`. + if step > warmup_steps + decay_steps: + return min_lr + + # If we are done with the warmup period, use the decay style. 
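+    # The half-cosine coefficient below moves from 1 to 0 as decay_ratio goes
+    # from 0 to 1, so the returned LR anneals from max_lr down to min_lr.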
+ num_steps_ = step - warmup_steps + decay_steps_ = decay_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = max_lr - min_lr + + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + + return min_lr + coeff * delta_lr + + +def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): + if cycle: + multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) + decay_steps *= multiplier + else: + step = min(step, decay_steps) + p = step / decay_steps + lr = (initial_lr - min_lr) * math.pow(1.0 - p, power) + lr += min_lr + return lr + + +def _noam_hold_annealing(initial_lr, step, warmup_steps, + hold_steps, decay_rate, min_lr): + # hold_steps = total number of steps + # to hold the LR, not the warmup + hold steps. + T_warmup_decay = max(1, warmup_steps ** decay_rate) + T_hold_decay = max(1, (step - hold_steps) ** decay_rate) + lr = (initial_lr * T_warmup_decay) / T_hold_decay + lr = max(lr, min_lr) + return lr + + +class SquareAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _square_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class SquareRootAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _squareroot_annealing(initial_lr=initial_lr, step=step, + max_steps=self.max_steps, min_lr=self.min_lr) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class CosineAnnealing(WarmupAnnealHoldPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate." + ) + + if self.constant_steps is None or self.constant_steps == 0: + new_lrs = [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + else: + new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step) + return new_lrs + + def _get_warmup_lr(self, step): + if self.constant_steps is None or self.constant_steps == 0: + return super()._get_warmup_lr(step) + else: + # Use linear warmup for the initial part. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_constant_lr(self, step): + # Only called when `constant_steps` > 0. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_linear_warmup_with_cosine_annealing_lr(self, step): + # Cosine Schedule for Megatron LM, + # slightly different warmup schedule + constant LR at the end. 
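+        # Every parameter group is driven by the same peak LR, base_lrs[0].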
+        new_lrs = [
+            _linear_warmup_with_cosine_annealing(
+                max_lr=self.base_lrs[0],
+                warmup_steps=self.warmup_steps,
+                step=step,
+                decay_steps=self.decay_steps,
+                min_lr=self.min_lr,
+            )
+            for _ in self.base_lrs
+        ]
+        return new_lrs
+
+
+class NoamAnnealing(_LRScheduler):
+    def __init__(
+            self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None,
+            max_steps=None, min_lr=0.0, last_epoch=-1
+    ):
+        self._normalize = d_model ** (-0.5)
+        assert not (warmup_steps is not None
+                    and warmup_ratio is not None), \
+            "Either use particular number of step or ratio"
+        assert warmup_ratio is None or max_steps is not None, \
+            "If there is a ratio, there should be a total steps"
+
+        # It is necessary to assign all attributes *before* __init__,
+        # as class is wrapped by an inner class.
+        self.max_steps = max_steps
+        if warmup_steps is not None:
+            self.warmup_steps = warmup_steps
+        elif warmup_ratio is not None:
+            self.warmup_steps = int(warmup_ratio * max_steps)
+        else:
+            self.warmup_steps = 0
+
+        self.min_lr = min_lr
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn(
+                "To get the last learning rate computed "
+                "by the scheduler, please use `get_last_lr()`.",
+                UserWarning
+            )
+
+        step = max(1, self.last_epoch)
+
+        for initial_lr in self.base_lrs:
+            if initial_lr < self.min_lr:
+                raise ValueError(
+                    f"{self} received an initial learning rate "
+                    f"that was lower than the minimum learning rate."
+                )
+
+        new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for
+                   initial_lr in self.base_lrs]
+        return new_lrs
+
+    def _noam_annealing(self, initial_lr, step):
+        if self.warmup_steps > 0:
+            mult = self._normalize * min(step ** (-0.5),
+                                         step * (self.warmup_steps ** (-1.5)))
+        else:
+            mult = self._normalize * step ** (-0.5)
+
+        out_lr = initial_lr * mult
+        if step > self.warmup_steps:
+            out_lr = max(out_lr, self.min_lr)
+        return out_lr
+
+
+class NoamHoldAnnealing(WarmupHoldPolicy):
+    def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0,
+                 last_epoch=-1, **kwargs):
+        """
+        From NeMo:
+        Implementation of the Noam Hold Annealing policy
+        from the Squeezeformer paper.
+
+        Unlike NoamAnnealing, the peak learning rate
+        can be explicitly set for this scheduler.
+        The schedule first performs linear warmup,
+        then holds the peak LR, then decays with some schedule for
+        the remainder of the steps.
+        The minimum LR reached therefore still depends
+        on the selected hyperparameters.
+
+        Its schedule is determined by three factors:
+
+        Warmup Steps: Initial stage, where linear warmup
+        occurs until the peak LR is reached. Unlike NoamAnnealing,
+        the peak LR is explicitly stated here instead of a scaling factor.
+
+        Hold Steps: Intermediate stage, where the peak LR
+        is maintained for some number of steps. In this region,
+        the high peak LR allows the model to converge faster
+        if training is stable. However, the high LR
+        may also cause instability during training.
+        This stage should usually cover a significant fraction of
+        training (around 30-40% of the total training steps).
+
+        Decay Steps: Final stage, where the LR rapidly decays
+        with some scaling rate (set by decay_rate).
+        To attain Noam decay, use 0.5;
+        for the Squeezeformer recommended decay, use 1.0.
+        The fast decay after the prolonged high-LR hold phase
+        allows for rapid convergence.
+ + References: + - [Squeezeformer: + An Efficient Transformer for Automatic Speech Recognition] + (https://arxiv.org/abs/2206.00888) + + Args: + optimizer: Pytorch compatible Optimizer object. + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to + hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + decay_rate: Float value describing the polynomial decay + after the hold period. Default value + of 0.5 corresponds to Noam decay. + min_lr: Minimum learning rate. + """ + self.decay_rate = decay_rate + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + if self.warmup_steps is None or self.warmup_steps == 0: + raise ValueError( + "Noam scheduler cannot be used without warmup steps") + + if self.hold_steps > 0: + hold_steps = self.hold_steps - self.warmup_steps + else: + hold_steps = 0 + + new_lrs = [ + _noam_hold_annealing( + initial_lr, + step=step, + warmup_steps=self.warmup_steps, + hold_steps=hold_steps, + decay_rate=self.decay_rate, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + def set_step(self, step: int): + self.last_epoch = step
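For reference, a minimal usage sketch of the NoamHoldAnnealing policy added above, wired to AdamW; the model, step budget and hyperparameter values below are placeholders rather than the recipe's settings:

    import torch
    from wenet.utils.scheduler import NoamHoldAnnealing

    model = torch.nn.Linear(80, 256)   # stand-in for the acoustic model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = NoamHoldAnnealing(
        optimizer,
        max_steps=100000,     # total number of training steps (placeholder)
        warmup_ratio=0.2,     # linear warmup over the first 20% of steps
        hold_ratio=0.3,       # hold the peak LR for the next 30% of steps
        decay_rate=1.0,       # 1.0 = Squeezeformer-style decay, 0.5 = Noam decay
        min_lr=1e-5,
    )

    for _ in range(10):       # one optimizer/scheduler step per training batch
        optimizer.step()
        scheduler.step()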