From 23c4f99ea4d1fbcbbca967ef90ad9a397aa8d555 Mon Sep 17 00:00:00 2001 From: xmly Date: Thu, 15 Sep 2022 12:56:51 +0800 Subject: [PATCH 01/29] [init] enable SqueezeformerEncoder --- wenet/squeezeformer/attention.py | 200 ++++++++++++++++ wenet/squeezeformer/conv2d.py | 47 ++++ wenet/squeezeformer/convolution.py | 172 ++++++++++++++ wenet/squeezeformer/encoder.py | 221 ++++++++++++++++++ wenet/squeezeformer/encoder_layer.py | 107 +++++++++ .../positionwise_feed_forward.py | 75 ++++++ wenet/squeezeformer/subsampling.py | 100 ++++++++ wenet/squeezeformer/utils.py | 25 ++ 8 files changed, 947 insertions(+) create mode 100644 wenet/squeezeformer/attention.py create mode 100644 wenet/squeezeformer/conv2d.py create mode 100644 wenet/squeezeformer/convolution.py create mode 100644 wenet/squeezeformer/encoder.py create mode 100644 wenet/squeezeformer/encoder_layer.py create mode 100644 wenet/squeezeformer/positionwise_feed_forward.py create mode 100644 wenet/squeezeformer/subsampling.py create mode 100644 wenet/squeezeformer/utils.py diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py new file mode 100644 index 000000000..e7e06ee13 --- /dev/null +++ b/wenet/squeezeformer/attention.py @@ -0,0 +1,200 @@ +import math +import torch +import torch.nn as nn +from wenet.transformer.attention import MultiHeadedAttention +from typing import Tuple + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + """ + + def __init__(self, n_head, n_feat, dropout_rate, do_rel_shift=False, adaptive_scale=False, init_weights=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.do_rel_shift = do_rel_shift + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + self.adaptive_scale = adaptive_scale + if self.adaptive_scale: + self.scale = nn.Parameter(torch.tensor(1.), requires_grad=True) + self.bias = nn.Parameter(torch.tensor(0.), requires_grad=True) + if init_weights: + self.init_weights() + + def init_weights(self): + input_max = (self.h * self.d_k) ** -0.5 + torch.nn.init.uniform_(self.linear_q.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_q.bias, -input_max, input_max) + torch.nn.init.uniform_(self.linear_k.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_k.bias, -input_max, input_max) + torch.nn.init.uniform_(self.linear_v.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_v.bias, -input_max, input_max) + torch.nn.init.uniform_(self.linear_pos.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_out.weight, -input_max, input_max) + torch.nn.init.uniform_(self.linear_out.bias, -input_max, input_max) + + def rel_shift(self, x, zero_triu: bool = False): + """Compute relative positinal encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, size). 
+ zero_triu (bool): If true, return the lower triangular part of + the matrix. + Returns: + torch.Tensor: Output tensor. + """ + + zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), + device=x.device, + dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(x.size()[0], + x.size()[1], + x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward_attention( + self, value: torch.Tensor, scores: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) + ) -> torch.Tensor: + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value, size + (#batch, n_head, time2, d_k). + scores (torch.Tensor): Attention score, size + (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask, size (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + # NOTE(xcsong): When will `if mask.size(2) > 0` be True? + # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the + # 1st chunk to ease the onnx export.] + # 2. pytorch training + if mask.size(2) > 0: # time2 > 0 + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + # For last chunk, time2 might be larger than scores.size(-1) + mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) + scores = scores.masked_fill(mask, -float('inf')) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + # NOTE(xcsong): When will `if mask.size(2) > 0` be False? + # 1. onnx(16/-1, -1/-1, 16/0) + # 2. jit (16/-1, -1/-1, 16/0, 16/4) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = (x.transpose(1, 2).contiguous().view(n_batch, -1, + self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query: torch.Tensor, + key: torch.Tensor, value: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + pos_emb: torch.Tensor = torch.empty(0), + cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2), (0, 0, 0) means fake mask. + pos_emb (torch.Tensor): Positional embedding tensor + (#batch, time2, size). + cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), + where `cache_t == chunk_size * num_decoding_left_chunks` + and `head * d_k == size` + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). 
+            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+                where `cache_t == chunk_size * num_decoding_left_chunks`
+                and `head * d_k == size`
+        """
+        if self.adaptive_scale:
+            query = self.scale * query + self.bias
+            key = self.scale * key + self.bias
+            value = self.scale * value + self.bias
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        # NOTE(xcsong):
+        #   when export onnx model, for 1st chunk, we feed
+        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
+        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
+        #       In all modes, `if cache.size(0) > 0` will always be `True`
+        #       and we will always do splitting and
+        #       concatenation (this will simplify onnx export). Note that
+        #       it's OK to concat & split zero-shaped tensors (see code below).
+        #   when export jit model, for 1st chunk, we always feed
+        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
+        # >>> a = torch.ones((1, 2, 0, 4))
+        # >>> b = torch.ones((1, 2, 3, 4))
+        # >>> c = torch.cat((a, b), dim=2)
+        # >>> torch.equal(b, c)        # True
+        # >>> d = torch.split(a, 2, dim=-1)
+        # >>> torch.equal(d[0], d[1])  # True
+        if cache.size(0) > 0:
+            key_cache, value_cache = torch.split(
+                cache, cache.size(-1) // 2, dim=-1)
+            k = torch.cat([key_cache, k], dim=2)
+            v = torch.cat([value_cache, v], dim=2)
+        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+        #   non-trivial to calculate `next_cache_start` here.
+        new_cache = torch.cat((k, v), dim=-1)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time2)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        # NOTE: rel_shift is optional here since it brings little gain in
+        #   speech recognition and requires special handling for streaming.
+        if self.do_rel_shift:
+            matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k)  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask), new_cache
diff --git a/wenet/squeezeformer/conv2d.py b/wenet/squeezeformer/conv2d.py
new file mode 100644
index 000000000..0cc545b33
--- /dev/null
+++ b/wenet/squeezeformer/conv2d.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn.functional as F
+from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional
+
+
+class Conv2dValid(_ConvNd):
+    """
+    Conv2d operator for VALID mode padding.
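+    When `valid_trigx`/`valid_trigy` are set, the padding amount is computed
+    from the input size at forward time (TensorFlow-style dynamic padding)
+    rather than being fixed at construction like `torch.nn.Conv2d`; the
+    time-reduction layer below relies on this behaviour.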
+ """ + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: _size_2_t, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = 'zeros', # TODO: refine this type + device=None, + dtype=None, + valid_trigx: bool = False, + valid_trigy: bool = False + ) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + kernel_size_ = _pair(kernel_size) + stride_ = _pair(stride) + padding_ = padding if isinstance(padding, str) else _pair(padding) + dilation_ = _pair(dilation) + super(Conv2dValid, self).__init__( + in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, + False, _pair(0), groups, bias, padding_mode, **factory_kwargs) + self.valid_trigx = valid_trigx + self.valid_trigy = valid_trigy + + def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + validx, validy = 0, 0 + if self.valid_trigx: + validx = (input.size(-2) * (self.stride[-2] - 1) - 1 + self.kernel_size[-2]) // 2 + if self.valid_trigy: + validy = (input.size(-1) * (self.stride[-1] - 1) - 1 + self.kernel_size[-1]) // 2 + return F.conv2d(input, weight, bias, self.stride, + (validx, validy), self.dilation, self.groups) + + def forward(self, input: Tensor) -> Tensor: + return self._conv_forward(input, self.weight, self.bias) \ No newline at end of file diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py new file mode 100644 index 000000000..055540954 --- /dev/null +++ b/wenet/squeezeformer/convolution.py @@ -0,0 +1,172 @@ +# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from ESPnet(https://github.com/espnet/espnet) + +"""ConvolutionModule definition.""" + +from typing import Tuple + +import torch +from torch import nn +from typeguard import check_argument_types + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model.""" + def __init__(self, + channels: int, + kernel_size: int = 15, + activation: nn.Module = nn.ReLU(), + norm: str = "batch_norm", + causal: bool = False, + bias: bool = True, + adaptive_scale: bool = False, + init_weights: bool = False + ): + """Construct an ConvolutionModule object. + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernel size of conv layers. 
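+            activation (nn.Module): Activation function applied after the
+                depthwise conv and norm.
+            norm (str): Normalization type, 'batch_norm' or 'layer_norm'.
+            bias (bool): Whether to use a bias term in the conv layers.
+            adaptive_scale (bool): Whether to scale the input by a learnable
+                scale and bias before the module.
+            init_weights (bool): Whether to apply the uniform init below.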
+            causal (bool): Whether to use causal convolution or not.
+        """
+        assert check_argument_types()
+        super().__init__()
+        # NOTE: the conv-bias flag is stored under a separate name so the
+        # adaptive-scale bias parameter below does not clobber it.
+        self.conv_bias = bias
+        self.channels = channels
+        self.kernel_size = kernel_size
+        self.adaptive_scale = adaptive_scale
+        if self.adaptive_scale:
+            self.scale = torch.nn.Parameter(torch.tensor(1.), requires_grad=True)
+            self.bias = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        # self.lorder is used to distinguish if it's a causal convolution,
+        # if self.lorder > 0: it's a causal convolution, the input will be
+        # padded with self.lorder frames on the left in forward.
+        # else: it's a symmetrical convolution
+        if causal:
+            padding = 0
+            self.lorder = kernel_size - 1
+        else:
+            # kernel_size should be an odd number for non-causal convolution
+            assert (kernel_size - 1) % 2 == 0
+            padding = (kernel_size - 1) // 2
+            self.lorder = 0
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+            bias=bias,
+        )
+
+        assert norm in ['batch_norm', 'layer_norm']
+        if norm == "batch_norm":
+            self.use_layer_norm = False
+            self.norm = nn.BatchNorm1d(channels)
+        else:
+            self.use_layer_norm = True
+            self.norm = nn.LayerNorm(channels)
+
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = activation
+        if init_weights:
+            self.init_weights()
+
+    def init_weights(self):
+        pw_max = self.channels ** -0.5
+        dw_max = self.kernel_size ** -0.5
+        torch.nn.init.uniform_(self.pointwise_conv1.weight.data, -pw_max, pw_max)
+        if self.conv_bias:
+            torch.nn.init.uniform_(self.pointwise_conv1.bias.data, -pw_max, pw_max)
+        torch.nn.init.uniform_(self.depthwise_conv.weight.data, -dw_max, dw_max)
+        if self.conv_bias:
+            torch.nn.init.uniform_(self.depthwise_conv.bias.data, -dw_max, dw_max)
+        torch.nn.init.uniform_(self.pointwise_conv2.weight.data, -pw_max, pw_max)
+        if self.conv_bias:
+            torch.nn.init.uniform_(self.pointwise_conv2.bias.data, -pw_max, pw_max)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+        cache: torch.Tensor = torch.zeros((0, 0, 0)),
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute convolution module.
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, channels).
+            mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+                (0, 0, 0) means fake mask.
+            cache (torch.Tensor): left context cache, it is only
+                used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, channels).
+        """
+        if self.adaptive_scale:
+            x = self.scale * x + self.bias
+        # exchange the temporal dimension and the feature dimension
+        x = x.transpose(1, 2)  # (#batch, channels, time)
+        # mask batch padding
+        if mask_pad.size(2) > 0:  # time > 0
+            x.masked_fill_(~mask_pad, 0.0)
+
+        if self.lorder > 0:
+            if cache.size(2) == 0:  # cache_t == 0
+                x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+            else:
+                assert cache.size(0) == x.size(0)  # equal batch
+                assert cache.size(1) == x.size(1)  # equal channel
+                x = torch.cat((cache, x), dim=2)
+            assert (x.size(2) > self.lorder)
+            new_cache = x[:, :, -self.lorder:]
+        else:
+            # It would be cleaner to return None when no cache is required,
+            # but for JIT export we fake a zero-shaped tensor instead of
+            # None.
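+            # The (0, 0, 0) shape doubles as a recognisable "no cache"
+            # sentinel while keeping the return type a Tensor for
+            # torch.jit.script.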
+            new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+        if self.use_layer_norm:
+            x = x.transpose(1, 2)
+        x = self.activation(self.norm(x))
+        if self.use_layer_norm:
+            x = x.transpose(1, 2)
+        x = self.pointwise_conv2(x)
+        # mask batch padding
+        if mask_pad.size(2) > 0:  # time > 0
+            x.masked_fill_(~mask_pad, 0.0)
+
+        return x.transpose(1, 2), new_cache
\ No newline at end of file
diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py
new file mode 100644
index 000000000..1c23a92b9
--- /dev/null
+++ b/wenet/squeezeformer/encoder.py
@@ -0,0 +1,221 @@
+import torch
+import torch.nn as nn
+from typing import Tuple
+from wenet.squeezeformer.utils import ResidualModule
+from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4, TimeReductionLayer
+from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
+from wenet.transformer.embedding import RelPositionalEncoding
+from wenet.transformer.attention import MultiHeadedAttention
+from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention
+from wenet.squeezeformer.positionwise_feed_forward import PositionwiseFeedForward
+from wenet.squeezeformer.convolution import ConvolutionModule
+from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask
+from wenet.utils.common import get_activation
+
+
+class SqueezeformerEncoder(nn.Module):
+    def __init__(
+            self,
+            input_size: int = 80,
+            encoder_dim: int = 256,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            num_blocks: int = 12,
+            reduce_idx: int = 5,
+            recover_idx: int = 11,
+            feed_forward_expansion_factor: int = 4,
+            input_dropout_rate: float = 0.1,
+            pos_enc_layer_type: str = "rel_pos",
+            do_rel_shift: bool = True,
+            feed_forward_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.1,
+            cnn_module_kernel: int = 31,
+            cnn_norm_type: str = "batch_norm",
+            dropout: float = 0.1,
+            causal: bool = False,
+            adaptive_scale: bool = True,
+            activation_type: str = "swish",
+            init_weights: bool = True,
+            global_cmvn: torch.nn.Module = None,
+            normalize_before: bool = False,
+            use_dynamic_chunk: bool = False,
+            concat_after: bool = False,
+            static_chunk_size: int = 0,
+            use_dynamic_left_chunk: bool = False
+    ):
+        """Construct a SqueezeformerEncoder object.
+
+        Args:
+            input_size to use_dynamic_chunk: see Transformer BaseEncoder.
+            encoder_dim (int): The hidden dimension of encoder layer.
+            output_size (int): The output dimension of final projection layer.
+            attention_heads (int): Num of attention head in attention module.
+            num_blocks (int): Num of encoder layers.
+            reduce_idx (int): reduce layer index, from 40ms to 80ms per frame.
+            recover_idx (int): recover layer index, from 80ms to 40ms per frame.
+            feed_forward_expansion_factor (int): Expansion factor of the FFN layer.
+            input_dropout_rate (float): Dropout rate of input projection layer.
+            pos_enc_layer_type (str): Self-attention type; 'rel_pos' selects relative position attention.
+            do_rel_shift (bool): Whether to do relative shift operation on rel-attention module.
+            cnn_module_kernel (int): Kernel size of CNN module.
+            activation_type (str): Encoder activation function type.
+            cnn_norm_type (str): Norm type of the convolution module,
+                'batch_norm' or 'layer_norm'.
+            adaptive_scale (bool): Whether to use adaptive scale.
+            init_weights (bool): Whether to initialize weights.
+ causal (bool): whether to use causal convolution or not. + """ + super(SqueezeformerEncoder, self).__init__() + self.global_cmvn = global_cmvn + self.reduce_idx = reduce_idx + self.recover_idx = recover_idx + self._output_size = output_size + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + activation = get_activation(activation_type) + + # self-attention module definition + if pos_enc_layer_type != "rel_pos": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + else: + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + encoder_dim, + attention_dropout_rate, + do_rel_shift, + adaptive_scale, + init_weights + ) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + encoder_dim, + encoder_dim * feed_forward_expansion_factor, + feed_forward_dropout_rate, + activation, + adaptive_scale, + init_weights + ) + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, + cnn_norm_type, causal, adaptive_scale, init_weights) + + self.embed = DepthwiseConv2dSubsampling4( + 1, encoder_dim, RelPositionalEncoding(encoder_dim, dropout_rate=0.1) + ) + self.input_proj = nn.Sequential( + nn.Linear(encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim), + nn.Dropout(p=input_dropout_rate), + ) + self.preln = nn.LayerNorm(encoder_dim) + self.encoders = torch.nn.ModuleList() + for layer_id in range(num_blocks): + if layer_id < reduce_idx: + self.encoders.append( + SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after + )) + elif reduce_idx <= layer_id < recover_idx: + self.encoders.append( + ResidualModule(SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after + ))) + else: + self.encoders.append( + SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after + )) + self.time_reduction_layer = TimeReductionLayer(encoder_dim=encoder_dim) + self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) + self.final_proj = None + if output_size != encoder_dim: + self.final_proj = nn.Linear(encoder_dim, output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + 
self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + xs_lens = chunk_masks.squeeze(1).sum(1) + xs = self.input_proj(xs) + xs = self.preln(xs) + recover_tensor = torch.tensor(0.) + recover_chunk_masks = torch.tensor(0.) + recover_pos_emb = torch.tensor(0.) + recover_mask_pad = torch.tensor(0.) + for idx, layer in enumerate(self.encoders): + if idx == self.reduce_idx: + recover_tensor = xs + recover_chunk_masks = chunk_masks + recover_pos_emb = pos_emb + recover_mask_pad = mask_pad + xs, xs_lens = self.time_reduction_layer(xs, xs_lens) + reduce_t = xs.size(1) + pos_emb = pos_emb[:, :reduce_t, :] + chunk_masks = chunk_masks[:, ::2, ::2] + mask_pad = mask_pad[:, :, ::2] + + if idx == self.recover_idx: + # recover output length for ctc decode + xs = xs.unsqueeze(2) + xs = xs.repeat(1, 1, 2, 1).flatten(1, 2) + xs = self.time_recover_layer(xs) + recover_t = recover_tensor.size(1) + xs = recover_tensor + xs[:, :recover_t, :].contiguous() + chunk_masks = recover_chunk_masks + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + + if self.final_proj is not None: + xs = self.final_proj(xs) + return xs, masks \ No newline at end of file diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py new file mode 100644 index 000000000..ac5d38630 --- /dev/null +++ b/wenet/squeezeformer/encoder_layer.py @@ -0,0 +1,107 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional, Union, Tuple, List, Dict + + +class SqueezeformerEncoderLayer(nn.Module): + """Encoder layer module. + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` + instance can be used as the argument. + feed_forward1 (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + feed_forward2 (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward` instance can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): + True: use layer_norm before each sub-block. + False: use layer_norm after each sub-block. 
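+        concat_after (bool): Whether to concat attention layer's input and
+            output.
+            True: x -> x + linear(concat(x, att(x)))
+            False: x -> x + att(x)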
+ """ + + def __init__( + self, + size: int, + self_attn: torch.nn.Module, + feed_forward1: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + feed_forward2: Optional[nn.Module] = None, + normalize_before: bool = False, + dropout_rate: float = 0.1, + concat_after: bool = False, + ): + super(SqueezeformerEncoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.layer_norm1 = nn.LayerNorm(size) + self.ffn1 = feed_forward1 + self.layer_norm2 = nn.LayerNorm(size) + self.conv_module = conv_module + self.layer_norm3 = nn.LayerNorm(size) + self.ffn2 = feed_forward2 + self.layer_norm4 = nn.LayerNorm(size) + self.normalize_before = normalize_before + self.dropout = nn.Dropout(dropout_rate) + self.concat_after = concat_after + if concat_after: + self.concat_linear = nn.Linear(size + size, size) + else: + self.concat_linear = nn.Identity() + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # self attention module + residual = x + if self.normalize_before: + x = self.layer_norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.layer_norm1(x) + + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm2(x) + x = self.ffn1(x) + # we do not use dropout here since it is inside feed forward function + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm2(x) + + # conv module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + residual = x + if self.normalize_before: + x = self.layer_norm3(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm3(x) + + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm4(x) + x = self.ffn2(x) + # we do not use dropout here since it is inside feed forward function + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm4(x) + + return x, mask, new_att_cache, new_cnn_cache diff --git a/wenet/squeezeformer/positionwise_feed_forward.py b/wenet/squeezeformer/positionwise_feed_forward.py new file mode 100644 index 000000000..3e2ae45ae --- /dev/null +++ b/wenet/squeezeformer/positionwise_feed_forward.py @@ -0,0 +1,75 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. 
+
+    The feed-forward layer is applied to each position of the sequence.
+    The output dim is the same as the input dim.
+
+    Args:
+        idim (int): Input dimension.
+        hidden_units (int): The number of hidden units.
+        dropout_rate (float): Dropout rate.
+        activation (torch.nn.Module): Activation function.
+    """
+
+    def __init__(self,
+                 idim: int,
+                 hidden_units: int,
+                 dropout_rate: float,
+                 activation: torch.nn.Module = torch.nn.ReLU(),
+                 adaptive_scale: bool = False,
+                 init_weights: bool = False
+                 ):
+        """Construct a PositionwiseFeedForward object."""
+        super(PositionwiseFeedForward, self).__init__()
+        self.idim = idim
+        self.hidden_units = hidden_units
+        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.activation = activation
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.w_2 = torch.nn.Linear(hidden_units, idim)
+        self.adaptive_scale = adaptive_scale
+        if self.adaptive_scale:
+            self.scale = torch.nn.Parameter(torch.tensor(1.), requires_grad=True)
+            self.bias = torch.nn.Parameter(torch.tensor(0.), requires_grad=True)
+        if init_weights:
+            self.init_weights()
+
+    def init_weights(self):
+        ffn1_max = self.idim ** -0.5
+        ffn2_max = self.hidden_units ** -0.5
+        torch.nn.init.uniform_(self.w_1.weight.data, -ffn1_max, ffn1_max)
+        torch.nn.init.uniform_(self.w_1.bias.data, -ffn1_max, ffn1_max)
+        torch.nn.init.uniform_(self.w_2.weight.data, -ffn2_max, ffn2_max)
+        torch.nn.init.uniform_(self.w_2.bias.data, -ffn2_max, ffn2_max)
+
+    def forward(self, xs: torch.Tensor) -> torch.Tensor:
+        """Forward function.
+
+        Args:
+            xs: input tensor (B, L, D)
+        Returns:
+            output tensor, (B, L, D)
+        """
+        if self.adaptive_scale:
+            xs = self.scale * xs + self.bias
+        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py
new file mode 100644
index 000000000..69f683615
--- /dev/null
+++ b/wenet/squeezeformer/subsampling.py
@@ -0,0 +1,100 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from wenet.transformer.subsampling import BaseSubsampling
+from typing import Tuple, Union, Optional, Dict
+from wenet.squeezeformer.conv2d import Conv2dValid
+
+
+class DepthwiseConv2dSubsampling4(BaseSubsampling):
+    """Depthwise Convolutional 2D subsampling (to 1/4 length).
+
+    Args:
+        idim (int): Number of input channels (the encoder passes 1).
+        odim (int): Output dimension.
+        pos_enc_class (nn.Module): position encoding class.
+ + """ + + def __init__( + self, idim: int, odim: int, pos_enc_class: torch.nn.Module): + super(DepthwiseConv2dSubsampling4, self).__init__() + self.idim = idim + self.odim = odim + self.pw_conv = nn.Conv2d(in_channels=idim, out_channels=odim, kernel_size=3, stride=2) + self.act1 = nn.ReLU() + self.dw_conv = nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2) + self.act2 = nn.ReLU() + self.pos_enc = pos_enc_class + self.subsampling_rate = 4 + # 6 = (3 - 1) * 1 + (3 - 1) * 2 + self.right_context = 6 + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + offset: int = 0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = x.unsqueeze(1) # (b, c=1, t, f) + x = self.pw_conv(x) + x = self.act1(x) + x = self.dw_conv(x) + x = self.act2(x) + b, c, t, f = x.size() + x = x.permute(0, 2, 1, 3) + x = x.contiguous().view(b, t, c * f) + x, pos_emb = self.pos_enc(x, offset) + return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] + + +class TimeReductionLayer(nn.Module): + def __init__( + self, ichannel: int = 1, ochannel: int = 1, + kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256): + super(TimeReductionLayer, self).__init__() + self.ichannel = ichannel + self.kernel_size = kernel_size + self.dw_conv = Conv2dValid( + in_channels=ichannel, + out_channels=ochannel, + kernel_size=(kernel_size, 1), + stride=stride, + valid_trigy=True + ) + self.pw_conv = Conv2dValid( + in_channels=encoder_dim, + out_channels=encoder_dim, + kernel_size=1, + stride=1, + valid_trigx=False, + valid_trigy=False, + ) + + self.in_channels = ichannel + self.kernel_size = kernel_size + self.stride = stride + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.ichannel ** -0.5 + torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) + + def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + xs = xs.unsqueeze(2) + padding1 = self.kernel_size - self.stride + xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) 
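+        # Right-pad time by (kernel_size - stride) so that the VALID-mode
+        # depthwise conv below roughly halves the sequence length; the dummy
+        # padding further down restores the frame that odd lengths lose.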
+ xs = self.dw_conv(xs.transpose(1, 2)) + xs = xs.permute(0, 3, 1, 2).contiguous() + xs = self.pw_conv(xs).permute(0, 2, 3, 1).squeeze(1).contiguous() + tmp_length = xs.size(1) + xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc') + padding2 = max(0, (xs_lens.max() - tmp_length).data.item()) + batch_size, hidden = xs.size(0), xs.size(-1) + dummy_pad = torch.zeros(batch_size, padding2, hidden).to(xs.device) + xs = torch.cat([xs, dummy_pad], dim=1) + return xs, xs_lens diff --git a/wenet/squeezeformer/utils.py b/wenet/squeezeformer/utils.py new file mode 100644 index 000000000..eee6c1381 --- /dev/null +++ b/wenet/squeezeformer/utils.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn +import termcolor + + +class ResidualModule(nn.Module): + """ + Residual Connection Module for Squeezeformer Encoder Layer + """ + def __init__(self, layer: nn.Module, coef: float = 1.0): + super(ResidualModule, self).__init__() + self.layer = layer + self.coef = coef + + def forward(self, + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ): + x, mask, new_att_cache, new_cnn_cache = self.layer(x, mask, pos_emb, mask_pad, att_cache, cnn_cache) + x = x * self.coef + x + return x, mask, new_att_cache, new_cnn_cache \ No newline at end of file From 7c24031e2d1b873de995e177a48587261e3d6d4e Mon Sep 17 00:00:00 2001 From: xmly Date: Thu, 15 Sep 2022 13:37:49 +0800 Subject: [PATCH 02/29] [update] enable Squeezeformer training --- examples/librispeech/squeezeformer/README.md | 22 + .../conf/train_squeezeformer.yaml | 88 +++ examples/librispeech/squeezeformer/local | 1 + examples/librispeech/squeezeformer/tools | 1 + examples/librispeech/squeezeformer/wenet | 1 + wenet/bin/train.py | 17 +- wenet/utils/init_model.py | 5 + wenet/utils/scheduler.py | 548 +++++++++++++++++- 8 files changed, 676 insertions(+), 7 deletions(-) create mode 100644 examples/librispeech/squeezeformer/README.md create mode 100644 examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml create mode 120000 examples/librispeech/squeezeformer/local create mode 120000 examples/librispeech/squeezeformer/tools create mode 120000 examples/librispeech/squeezeformer/wenet diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md new file mode 100644 index 000000000..7da562d3d --- /dev/null +++ b/examples/librispeech/squeezeformer/README.md @@ -0,0 +1,22 @@ +# Performance Record + + + +### Conformer +* encoder flops(30s): 2,797,274,624, params: 34,761,608 + + +### Squeezeformer Result (SM12, FFN:1024) +* encoder flops(30s): 21,158,877,440, params: 22,219,912 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: train_conformer_bidecoder_large.yaml, kernel size 31, lr 0.002, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | +| ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | +| attention decoder | 8.74 | 3.59 | 3.75 | 8.70 | +| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | + + diff --git a/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml 
b/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml new file mode 100644 index 000000000..202c95f42 --- /dev/null +++ b/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml @@ -0,0 +1,88 @@ +# network architecture +# encoder related +encoder: squeezeformer +encoder_conf: + encoder_dim: 256 + output_size: 256 # dimension of attention + attention_heads: 4 + num_blocks: 12 # the number of encoder blocks + reduce_idx: 5 + recover_idx: 11 + feed_forward_expansion_factor: 4 + input_dropout_rate: 0.1 + feed_forward_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + cnn_module_kernel: 31 + cnn_norm_type: layer_norm + adaptive_scale: true + normalize_before: false + +# decoder related +decoder: transformer +decoder_conf: + attention_heads: 4 + linear_units: 2048 + num_blocks: 6 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.0 + src_attention_dropout_rate: 0.0 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + +# dataset related +dataset_conf: + even_sample: false + syncbn: false + filter_conf: + max_length: 2000 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + min_output_input_ratio: 0.0005 + max_output_input_ratio: 0.1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 0.1 + spec_aug: true + spec_aug_conf: + num_t_mask: 2 + num_f_mask: 2 + max_t: 50 + max_f: 10 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 12 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 120 +log_interval: 100 + +optim: adamw +optim_conf: + lr: 1.e-3 + weight_decay: 4.e-5 + +scheduler: NoamHoldAnnealing +scheduler_conf: + warmup_ratio: 0.2 + hold_ratio: 0.3 + max_steps: 87960 + decay_rate: 1.0 + min_lr: 1.e-5 diff --git a/examples/librispeech/squeezeformer/local b/examples/librispeech/squeezeformer/local new file mode 120000 index 000000000..ea4a20415 --- /dev/null +++ b/examples/librispeech/squeezeformer/local @@ -0,0 +1 @@ +../s0/local \ No newline at end of file diff --git a/examples/librispeech/squeezeformer/tools b/examples/librispeech/squeezeformer/tools new file mode 120000 index 000000000..c92f4172d --- /dev/null +++ b/examples/librispeech/squeezeformer/tools @@ -0,0 +1 @@ +../../../tools \ No newline at end of file diff --git a/examples/librispeech/squeezeformer/wenet b/examples/librispeech/squeezeformer/wenet new file mode 120000 index 000000000..702de77db --- /dev/null +++ b/examples/librispeech/squeezeformer/wenet @@ -0,0 +1 @@ +../../../wenet \ No newline at end of file diff --git a/wenet/bin/train.py b/wenet/bin/train.py index f03404676..6fabc7b5b 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -31,7 +31,7 @@ load_trained_modules) from wenet.utils.executor import Executor from wenet.utils.file_utils import read_symbol_table, read_non_lang_symbols -from wenet.utils.scheduler import WarmupLR +from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing from wenet.utils.config import override_config from wenet.utils.init_model import init_model @@ -243,8 +243,19 @@ def main(): device = torch.device('cuda' if use_cuda else 'cpu') model = model.to(device) - optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) - scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) + if 
configs['optim'] == 'adam': + optimizer = optim.Adam(model.parameters(), **configs['optim_conf']) + elif configs['optim'] == 'adamw': + optimizer = optim.AdamW(model.parameters(), **configs['optim_conf']) + else: + raise Exception('Please choose a correct optimizer.') + if configs['scheduler'] == 'warmuplr': + scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) + elif configs['scheduler'] == 'NoamHoldAnnealing': + scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf']) + else: + raise Exception('Please choose a correct scheduler.') + final_epoch = None configs['rank'] = args.rank configs['is_distributed'] = distributed diff --git a/wenet/utils/init_model.py b/wenet/utils/init_model.py index 1e90c7b18..b739c82be 100644 --- a/wenet/utils/init_model.py +++ b/wenet/utils/init_model.py @@ -22,6 +22,7 @@ from wenet.transformer.ctc import CTC from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder +from wenet.squeezeformer.encoder import SqueezeformerEncoder from wenet.utils.cmvn import load_cmvn @@ -44,6 +45,10 @@ def init_model(configs): encoder = ConformerEncoder(input_dim, global_cmvn=global_cmvn, **configs['encoder_conf']) + elif encoder_type == 'squeezeformer': + encoder = SqueezeformerEncoder(input_dim, + global_cmvn=global_cmvn, + **configs['encoder_conf']) else: encoder = TransformerEncoder(input_dim, global_cmvn=global_cmvn, diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index e4705cdf4..dcc65022d 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -15,6 +15,8 @@ from typing import Union +import math +import warnings import torch from torch.optim.lr_scheduler import _LRScheduler @@ -39,10 +41,10 @@ class WarmupLR(_LRScheduler): """ def __init__( - self, - optimizer: torch.optim.Optimizer, - warmup_steps: Union[int, float] = 25000, - last_epoch: int = -1, + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, ): assert check_argument_types() self.warmup_steps = warmup_steps @@ -71,3 +73,541 @@ def get_lr(self): def set_step(self, step: int): self.last_epoch = step + + +class WarmupPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1): + assert not ( + warmup_steps is not None and warmup_ratio is not None + ), "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
+ self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class SquareRootConstantPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + ): + assert not ( + constant_steps is not None and constant_ratio is not None + ), "Either use particular number of step or ratio" + assert constant_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. + self.max_steps = max_steps + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.constant_lr = 1 / (constant_steps ** 0.5) + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + if step <= self.constant_steps: + return [self.constant_lr for _ in self.base_lrs] + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +class WarmupHoldPolicy(WarmupPolicy): + """Variant of WarmupPolicy which maintains high learning rate for a defined number of steps. 
+ All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + hold_steps: Number of training steps to hold the learning rate after warm up + hold_ratio: Ratio of hold steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + hold_steps=None, + hold_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not (hold_steps is not None and hold_ratio is not None), "Either use particular number of step or ratio" + assert hold_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + self.min_lr = min_lr + self._last_warmup_lr = 0.0 + + # Necessary to duplicate as class attributes are hidden in inner class + self.max_steps = max_steps + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if hold_steps is not None: + self.hold_steps = hold_steps + self.warmup_steps + elif hold_ratio is not None: + self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps + else: + self.hold_steps = 0 + + super().__init__( + optimizer, + warmup_steps=warmup_steps, + warmup_ratio=warmup_ratio, + max_steps=max_steps, + last_epoch=last_epoch, + min_lr=min_lr, + ) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + # Warmup phase + if step <= self.warmup_steps and self.warmup_steps > 0: + return self._get_warmup_lr(step) + + # Hold phase + if (step >= self.warmup_steps) and (step < self.hold_steps): + return self.base_lrs + + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + +class WarmupAnnealHoldPolicy(_LRScheduler): + """Adds warmup kwargs and warmup logic to lr policy. + All arguments should be passed as kwargs for clarity, + Args: + warmup_steps: Number of training steps in warmup stage + warmup_ratio: Ratio of warmup steps to total steps + max_steps: Total number of steps while training or `None` for + infinite training + min_lr: Minimum lr to hold the learning rate after decay at. + constant_steps: Number of steps to keep lr constant at. + constant_ratio: Ratio of steps to keep lr constant. + """ + + def __init__( + self, + optimizer, + *, + warmup_steps=None, + warmup_ratio=None, + constant_steps=None, + constant_ratio=None, + max_steps=None, + min_lr=0.0, + last_epoch=-1, + ): + assert not ( + warmup_steps is not None and warmup_ratio is not None + ), "Either use particular number of step or ratio" + assert not ( + constant_steps is not None and constant_ratio is not None + ), "Either use constant_steps or constant_ratio" + assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
+ self.max_steps = max_steps + + if warmup_steps is not None: + self.warmup_steps = warmup_steps + elif warmup_ratio is not None: + self.warmup_steps = int(warmup_ratio * max_steps) + else: + self.warmup_steps = 0 + + if constant_steps is not None: + self.constant_steps = constant_steps + elif constant_ratio is not None: + self.constant_steps = int(constant_ratio * max_steps) + else: + self.constant_steps = 0 + + self.decay_steps = max_steps - (self.constant_steps + self.warmup_steps) + + self.min_lr = min_lr + super().__init__(optimizer, last_epoch) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + ) + + step = self.last_epoch + + # Warmup steps + if self.warmup_steps > 0 and step <= self.warmup_steps: + return self._get_warmup_lr(step) + + # Constant steps after warmup and decay + if self.constant_steps > 0 and (self.warmup_steps + self.decay_steps) < step <= self.max_steps: + return self._get_constant_lr(step) + + # Min lr after max steps of updates + if step > self.max_steps: + return [self.min_lr for _ in self.base_lrs] + + return self._get_lr(step) + + def _get_warmup_lr(self, step): + lr_val = (step + 1) / (self.warmup_steps + 1) + return [initial_lr * lr_val for initial_lr in self.base_lrs] + + def _get_constant_lr(self, step): + return [self.min_lr for _ in self.base_lrs] + + def _get_lr(self, step): + """Simple const lr policy""" + return self.base_lrs + + +def _squareroot_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps) ** 0.5 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _square_annealing(initial_lr, step, max_steps, min_lr): + mult = ((max_steps - step) / max_steps) ** 2 + out_lr = initial_lr * mult + out_lr = max(out_lr, min_lr) + return out_lr + + +def _cosine_annealing(initial_lr, step, max_steps, min_lr): + mult = 0.5 * (1 + math.cos(math.pi * step / max_steps)) + out_lr = (initial_lr - min_lr) * mult + min_lr + return out_lr + + +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr): + assert max_lr > min_lr + # Use linear warmup for the initial part. + if warmup_steps > 0 and step <= warmup_steps: + return max_lr * float(step) / float(warmup_steps) + + # For any steps larger than `decay_steps`, use `min_lr`. + if step > warmup_steps + decay_steps: + return min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = step - warmup_steps + decay_steps_ = decay_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + assert decay_ratio >= 0.0 + assert decay_ratio <= 1.0 + delta_lr = max_lr - min_lr + + coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0) + + return min_lr + coeff * delta_lr + + +def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): + if cycle: + multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps) + decay_steps *= multiplier + else: + step = min(step, decay_steps) + p = step / decay_steps + lr = (initial_lr - min_lr) * math.pow(1.0 - p, power) + lr += min_lr + return lr + + +def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr): + # hold_steps = total number of steps to hold the LR, not the warmup + hold steps. 
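+    # Closed form (ignoring the max(1, .) clamps):
+    #   lr(step) = initial_lr * (warmup_steps / (step - hold_steps)) ** decay_rate
+    # decay_rate=0.5 gives the standard Noam decay; 1.0 is the faster decay
+    # recommended for Squeezeformer.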
+ T_warmup_decay = max(1, warmup_steps ** decay_rate) + T_hold_decay = max(1, (step - hold_steps) ** decay_rate) + lr = (initial_lr * T_warmup_decay) / T_hold_decay + lr = max(lr, min_lr) + return lr + + +class SquareAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _square_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class SquareRootAnnealing(WarmupPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + new_lrs = [ + _squareroot_annealing(initial_lr=initial_lr, step=step, max_steps=self.max_steps, min_lr=self.min_lr) + for initial_lr in self.base_lrs + ] + return new_lrs + + +class CosineAnnealing(WarmupAnnealHoldPolicy): + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + for initial_lr in self.base_lrs: + if initial_lr < self.min_lr: + raise ValueError( + f"{self} received an initial learning rate that was lower than the minimum learning rate." + ) + + if self.constant_steps is None or self.constant_steps == 0: + new_lrs = [ + _cosine_annealing( + initial_lr=initial_lr, + step=step - self.warmup_steps, + max_steps=self.max_steps - self.warmup_steps, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + else: + new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step) + return new_lrs + + def _get_warmup_lr(self, step): + if self.constant_steps is None or self.constant_steps == 0: + return super()._get_warmup_lr(step) + else: + # Use linear warmup for the initial part. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_constant_lr(self, step): + # Only called when `constant_steps` > 0. + return self._get_linear_warmup_with_cosine_annealing_lr(step) + + def _get_linear_warmup_with_cosine_annealing_lr(self, step): + # Cosine Schedule for Megatron LM, slightly different warmup schedule + constant LR at the end. + new_lrs = [ + _linear_warmup_with_cosine_annealing( + max_lr=self.base_lrs[0], + warmup_steps=self.warmup_steps, + step=step, + decay_steps=self.decay_steps, + min_lr=self.min_lr, + ) + for _ in self.base_lrs + ] + return new_lrs + + +class NoamAnnealing(_LRScheduler): + def __init__( + self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + ): + self._normalize = d_model ** (-0.5) + assert not ( + warmup_steps is not None and warmup_ratio is not None + ), "Either use particular number of step or ratio" + assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + + # It is necessary to assign all attributes *before* __init__, + # as class is wrapped by an inner class. 
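+        # The resulting schedule is the classic Noam/Transformer one:
+        #   lr = initial_lr * d_model**-0.5
+        #        * min(step**-0.5, step * warmup_steps**-1.5)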
+        self.max_steps = max_steps
+        if warmup_steps is not None:
+            self.warmup_steps = warmup_steps
+        elif warmup_ratio is not None:
+            self.warmup_steps = int(warmup_ratio * max_steps)
+        else:
+            self.warmup_steps = 0
+
+        self.min_lr = min_lr
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
+            )
+
+        step = max(1, self.last_epoch)
+
+        for initial_lr in self.base_lrs:
+            if initial_lr < self.min_lr:
+                raise ValueError(
+                    f"{self} received an initial learning rate that was lower than the minimum learning rate."
+                )
+
+        new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for initial_lr in self.base_lrs]
+        return new_lrs
+
+    def _noam_annealing(self, initial_lr, step):
+        if self.warmup_steps > 0:
+            mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))
+        else:
+            mult = self._normalize * step ** (-0.5)
+
+        out_lr = initial_lr * mult
+        if step > self.warmup_steps:
+            out_lr = max(out_lr, self.min_lr)
+        return out_lr
+
+
+class NoamHoldAnnealing(WarmupHoldPolicy):
+    def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0, last_epoch=-1, **kwargs):
+        """
+        From NeMo:
+        Implementation of the Noam Hold Annealing policy from the SqueezeFormer paper.
+
+        Unlike NoamAnnealing, the peak learning rate can be explicitly set for this scheduler.
+        The schedule first performs linear warmup, then holds the peak LR, then decays with some schedule for
+        the remainder of the steps. Therefore the min_lr is still dependent on the hyperparameters selected.
+
+        Its schedule is determined by three factors:
+
+        Warmup Steps: Initial stage, where linear warmup occurs until the peak LR is reached. Unlike NoamAnnealing,
+            the peak LR is explicitly stated here instead of a scaling factor.
+
+        Hold Steps: Intermediate stage, where the peak LR is maintained for some number of steps. In this region,
+            the high peak LR allows the model to converge faster if training is stable. However, the high LR
+            may also cause instability during training. Should usually be a significant fraction of training
+            steps (around 30-40% of the entire training steps).
+
+        Decay Steps: Final stage, where the LR rapidly decays with some scaling rate (set by decay rate).
+            To attain Noam decay, use 0.5; for the Squeezeformer recommended decay, use 1.0. The fast decay after
+            prolonged high LR during the hold phase allows for rapid convergence.
+
+        References:
+            - [Squeezeformer: An Efficient Transformer for Automatic Speech Recognition](https://arxiv.org/abs/2206.00888)
+
+        Args:
+            optimizer: PyTorch compatible Optimizer object.
+            warmup_steps: Number of training steps in warmup stage
+            warmup_ratio: Ratio of warmup steps to total steps
+            hold_steps: Number of training steps to hold the learning rate after warm up
+            hold_ratio: Ratio of hold steps to total steps
+            max_steps: Total number of steps while training or `None` for
+                infinite training
+            decay_rate: Float value describing the polynomial decay after the hold period. Default value
+                of 0.5 corresponds to Noam decay.
+            min_lr: Minimum learning rate.
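+
+        Example (values mirroring conf/train_squeezeformer.yaml in this repo;
+        assumes `optimizer` has already been built with its peak lr):
+            >>> scheduler = NoamHoldAnnealing(
+            ...     optimizer, max_steps=87960, warmup_ratio=0.2,
+            ...     hold_ratio=0.3, decay_rate=1.0, min_lr=1.e-5)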
+ """ + self.decay_rate = decay_rate + super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + + def _get_lr(self, step): + if self.warmup_steps is None or self.warmup_steps == 0: + raise ValueError("Noam scheduler cannot be used without warmup steps") + + if self.hold_steps > 0: + hold_steps = self.hold_steps - self.warmup_steps + else: + hold_steps = 0 + + new_lrs = [ + _noam_hold_annealing( + initial_lr, + step=step, + warmup_steps=self.warmup_steps, + hold_steps=hold_steps, + decay_rate=self.decay_rate, + min_lr=self.min_lr, + ) + for initial_lr in self.base_lrs + ] + return new_lrs + + def set_step(self, step: int): + self.last_epoch = step From 76ac435e3a9098c4d63ab7663264c5d09d38968c Mon Sep 17 00:00:00 2001 From: xmly Date: Thu, 15 Sep 2022 14:23:08 +0800 Subject: [PATCH 03/29] [update] README.md --- examples/librispeech/squeezeformer/README.md | 27 ++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index 7da562d3d..b8d1eda38 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -1,5 +1,28 @@ -# Performance Record +# Develop Record +``` +squeezeformer +├── attention.py # reltive multi-head attention module +├── conv2d.py # self defined conv2d valid padding module +├── convolution.py # convolution module in squeezeformer block +├── encoder_layer.py # squeezeformer encoder layer +├── encoder.py # squeezeformer encoder class +├── positionwise_feed_forward.py # feed forward layer +├── subsampling.py # sub-sampling layer, time reduction layer +└── utils.py # residual connection module +``` + +* Implementation Details + * Squeezeformer Encoder + * [x] add pre layer norm before squeezeformer block + * [x] derive time reduction layer from tensorflow version + * [x] enable adaptive scale operation + * [x] enable init weights for deep model training + * [x] enable training config and results + * [x] enable dynamic chunk and JIT export + * Training + * [x] enable NoamHoldAnnealing schedular +# Performance Record ### Conformer @@ -9,7 +32,7 @@ ### Squeezeformer Result (SM12, FFN:1024) * encoder flops(30s): 21,158,877,440, params: 22,219,912 * Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_conformer_bidecoder_large.yaml, kernel size 31, lr 0.002, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 +* Training info: train_squeezeformer.yaml, kernel size 31, lr 0.001, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 * Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 | decoding mode | dev clean | dev other | test clean | test other | From b291d1e2328b4bc656eb8cd7d3fb11cbb68b46dd Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 17:01:35 +0800 Subject: [PATCH 04/29] fix formatting issues --- examples/librispeech/squeezeformer/README.md | 9 +- wenet/utils/scheduler.py | 155 +++++++++++++------ 2 files changed, 108 insertions(+), 56 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index b8d1eda38..c73ae1c27 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -17,14 +17,13 @@ squeezeformer * [x] derive time reduction layer from tensorflow version * [x] enable adaptive scale operation * [x] enable init weights for deep model training - * [x] enable training 
config and results + * [x] enable training config and results * [x] enable dynamic chunk and JIT export * Training - * [x] enable NoamHoldAnnealing schedular + * [x] enable NoamHoldAnnealing schedular # Performance Record - ### Conformer * encoder flops(30s): 2,797,274,624, params: 34,761,608 @@ -40,6 +39,4 @@ squeezeformer | ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | | ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | | attention decoder | 8.74 | 3.59 | 3.75 | 8.70 | -| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | - - +| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | \ No newline at end of file diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index dcc65022d..e68255cbb 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -85,11 +85,13 @@ class WarmupPolicy(_LRScheduler): infinite training """ - def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1): + def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, + max_steps=None, min_lr=0.0, last_epoch=-1): assert not ( warmup_steps is not None and warmup_ratio is not None ), "Either use particular number of step or ratio" - assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" # It is necessary to assign all attributes *before* __init__, # as class is wrapped by an inner class. @@ -107,7 +109,9 @@ def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps def get_lr(self): if not self._get_lr_called_within_step: warnings.warn( - "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning ) step = self.last_epoch @@ -140,12 +144,14 @@ class SquareRootConstantPolicy(_LRScheduler): """ def __init__( - self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + self, optimizer, *, constant_steps=None, constant_ratio=None, + max_steps=None, min_lr=0.0, last_epoch=-1 ): assert not ( constant_steps is not None and constant_ratio is not None ), "Either use particular number of step or ratio" - assert constant_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + assert constant_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" # It is necessary to assign all attributes *before* __init__, # as class is wrapped by an inner class. @@ -164,7 +170,9 @@ def __init__( def get_lr(self): if not self._get_lr_called_within_step: warnings.warn( - "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning ) step = self.last_epoch @@ -183,12 +191,14 @@ def _get_lr(self, step): class WarmupHoldPolicy(WarmupPolicy): - """Variant of WarmupPolicy which maintains high learning rate for a defined number of steps. + """Variant of WarmupPolicy which maintains high + learning rate for a defined number of steps. 
All arguments should be passed as kwargs for clarity, Args: warmup_steps: Number of training steps in warmup stage warmup_ratio: Ratio of warmup steps to total steps - hold_steps: Number of training steps to hold the learning rate after warm up + hold_steps: Number of training steps to + hold the learning rate after warm up hold_ratio: Ratio of hold steps to total steps max_steps: Total number of steps while training or `None` for infinite training @@ -206,8 +216,10 @@ def __init__( min_lr=0.0, last_epoch=-1, ): - assert not (hold_steps is not None and hold_ratio is not None), "Either use particular number of step or ratio" - assert hold_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + assert not (hold_steps is not None and hold_ratio is not None), \ + "Either use particular number of step or ratio" + assert hold_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" self.min_lr = min_lr self._last_warmup_lr = 0.0 @@ -240,7 +252,9 @@ def __init__( def get_lr(self): if not self._get_lr_called_within_step: warnings.warn( - "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning + "To get the last learning rate computed by the scheduler," + " " "please use `get_last_lr()`.", + UserWarning ) step = self.last_epoch @@ -290,7 +304,8 @@ def __init__( assert not ( constant_steps is not None and constant_ratio is not None ), "Either use constant_steps or constant_ratio" - assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" # It is necessary to assign all attributes *before* __init__, # as class is wrapped by an inner class. @@ -318,7 +333,9 @@ def __init__( def get_lr(self): if not self._get_lr_called_within_step: warnings.warn( - "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning + "To get the last learning rate computed " + "by the scheduler, please use `get_last_lr()`.", + UserWarning ) step = self.last_epoch @@ -328,7 +345,8 @@ def get_lr(self): return self._get_warmup_lr(step) # Constant steps after warmup and decay - if self.constant_steps > 0 and (self.warmup_steps + self.decay_steps) < step <= self.max_steps: + if self.constant_steps > 0 and ( + self.warmup_steps + self.decay_steps) < step <= self.max_steps: return self._get_constant_lr(step) # Min lr after max steps of updates @@ -369,7 +387,8 @@ def _cosine_annealing(initial_lr, step, max_steps, min_lr): return out_lr -def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, decay_steps, min_lr): +def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step, + decay_steps, min_lr): assert max_lr > min_lr # Use linear warmup for the initial part. if warmup_steps > 0 and step <= warmup_steps: @@ -404,8 +423,10 @@ def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle): return lr -def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, min_lr): - # hold_steps = total number of steps to hold the LR, not the warmup + hold steps. +def _noam_hold_annealing(initial_lr, step, warmup_steps, + hold_steps, decay_rate, min_lr): + # hold_steps = total number of steps + # to hold the LR, not the warmup + hold steps. 
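
For context on how these policies are driven in practice, a minimal usage sketch (the model and optimizer are placeholders; the ratio values echo the recipe settings quoted in the README later in this series):

```python
import torch
from wenet.utils.scheduler import NoamHoldAnnealing

model = torch.nn.Linear(80, 256)                            # placeholder
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # lr is the peak
scheduler = NoamHoldAnnealing(optimizer, max_steps=100000,
                              warmup_ratio=0.2, hold_ratio=0.3,
                              decay_rate=1.0, min_lr=1e-5)

for step in range(100000):
    # ... forward / backward ...
    optimizer.step()
    scheduler.step()   # advances last_epoch; per-group LRs are recomputed
```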
T_warmup_decay = max(1, warmup_steps ** decay_rate) T_hold_decay = max(1, (step - hold_steps) ** decay_rate) lr = (initial_lr * T_warmup_decay) / T_hold_decay @@ -414,8 +435,10 @@ def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps, decay_rate, class SquareAnnealing(WarmupPolicy): - def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, **kwargs): - super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + def __init__(self, optimizer, *, max_steps, min_lr=1e-5, last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) def _get_lr(self, step): new_lrs = [ @@ -431,26 +454,32 @@ def _get_lr(self, step): class SquareRootAnnealing(WarmupPolicy): - def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): - super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) def _get_lr(self, step): new_lrs = [ - _squareroot_annealing(initial_lr=initial_lr, step=step, max_steps=self.max_steps, min_lr=self.min_lr) + _squareroot_annealing(initial_lr=initial_lr, step=step, + max_steps=self.max_steps, min_lr=self.min_lr) for initial_lr in self.base_lrs ] return new_lrs class CosineAnnealing(WarmupAnnealHoldPolicy): - def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, **kwargs): - super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + def __init__(self, optimizer, *, max_steps, min_lr=0, last_epoch=-1, + **kwargs): + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) def _get_lr(self, step): for initial_lr in self.base_lrs: if initial_lr < self.min_lr: raise ValueError( - f"{self} received an initial learning rate that was lower than the minimum learning rate." + f"{self} received an initial learning rate " + f"that was lower than the minimum learning rate." ) if self.constant_steps is None or self.constant_steps == 0: @@ -495,13 +524,15 @@ def _get_linear_warmup_with_cosine_annealing_lr(self, step): class NoamAnnealing(_LRScheduler): def __init__( - self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 + self, optimizer, *, d_model, warmup_steps=None, warmup_ratio=None, + max_steps=None, min_lr=0.0, last_epoch=-1 ): self._normalize = d_model ** (-0.5) assert not ( warmup_steps is not None and warmup_ratio is not None ), "Either use particular number of step or ratio" - assert warmup_ratio is None or max_steps is not None, "If there is a ratio, there should be a total steps" + assert warmup_ratio is None or max_steps is not None, \ + "If there is a ratio, there should be a total steps" # It is necessary to assign all attributes *before* __init__, # as class is wrapped by an inner class. 
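
One property of the `NoamAnnealing` rule being re-wrapped here is worth spelling out: the two branches of the `min()` in `_noam_annealing` cross exactly at `step == warmup_steps`, so the peak LR is `base_lr * d_model**-0.5 * warmup_steps**-0.5`. A quick check with illustrative numbers:

```python
# Both min() branches equal warmup_steps**-0.5 at step == warmup_steps,
# which is where the classic Noam curve peaks.
d_model, warmup_steps, base_lr = 256, 25000, 1.0
peak = base_lr * d_model ** -0.5 * warmup_steps ** -0.5
print(peak)  # ~3.95e-04
```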
@@ -519,7 +550,9 @@ def __init__(
     def get_lr(self):
         if not self._get_lr_called_within_step:
             warnings.warn(
-                "To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning
+                "To get the last learning rate computed "
+                "by the scheduler, please use `get_last_lr()`.",
+                UserWarning
             )
 
         step = max(1, self.last_epoch)
 
@@ -527,15 +560,18 @@ def get_lr(self):
         for initial_lr in self.base_lrs:
             if initial_lr < self.min_lr:
                 raise ValueError(
-                    f"{self} received an initial learning rate that was lower than the minimum learning rate."
+                    f"{self} received an initial learning rate "
+                    f"that was lower than the minimum learning rate."
                 )
 
-        new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for initial_lr in self.base_lrs]
+        new_lrs = [self._noam_annealing(initial_lr=initial_lr, step=step) for
+                   initial_lr in self.base_lrs]
         return new_lrs
 
     def _noam_annealing(self, initial_lr, step):
         if self.warmup_steps > 0:
-            mult = self._normalize * min(step ** (-0.5), step * (self.warmup_steps ** (-1.5)))
+            mult = self._normalize * min(step ** (-0.5),
+                                         step * (self.warmup_steps ** (-1.5)))
         else:
             mult = self._normalize * step ** (-0.5)
 
@@ -546,50 +582,69 @@ def _noam_annealing(self, initial_lr, step):
 
 
 class NoamHoldAnnealing(WarmupHoldPolicy):
-    def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0, last_epoch=-1, **kwargs):
+    def __init__(self, optimizer, *, max_steps, decay_rate=0.5, min_lr=0.0,
+                 last_epoch=-1, **kwargs):
         """
         From Nemo:
-        Implementation of the Noam Hold Annealing policy from the SqueezeFormer paper.
+        Implementation of the Noam Hold Annealing policy
+        from the SqueezeFormer paper.
 
-        Unlike NoamAnnealing, the peak learning rate can be explicitly set for this scheduler.
-        The schedule first performs linear warmup, then holds the peak LR, then decays with some schedule for
-        the remainder of the steps. Therefore the min-lr is still dependent on the hyperparameters selected.
+        Unlike NoamAnnealing, the peak learning rate
+        can be explicitly set for this scheduler.
+        The schedule first performs linear warmup,
+        then holds the peak LR, then decays with some schedule for
+        the remainder of the steps.
+        Therefore the min-lr is still dependent
+        on the hyperparameters selected.
 
         Its schedule is determined by three factors:
 
-        Warmup Steps: Initial stage, where linear warmup occurs until the peak LR is reached. Unlike NoamAnnealing,
+        Warmup Steps: Initial stage, where linear warmup
+        occurs until the peak LR is reached. Unlike NoamAnnealing,
             the peak LR is explicitly stated here instead of a scaling factor.
 
-        Hold Steps: Intermediate stage, where the peak LR is maintained for some number of steps. In this region,
-            the high peak LR allows the model to converge faster if training is stable. However the high LR
-            may also cause instability during training. Should usually be a significant fraction of training
+        Hold Steps: Intermediate stage, where the peak LR
+        is maintained for some number of steps. In this region,
+            the high peak LR allows the model to converge faster
+            if training is stable. However the high LR
+            may also cause instability during training.
+            Should usually be a significant fraction of training
             steps (around 30-40% of the entire training steps).
 
-        Decay Steps: Final stage, where the LR rapidly decays with some scaling rate (set by decay rate).
-            To attain Noam decay, use 0.5, for Squeezeformer recommended decay, use 1.0. The fast decay after
-            prolonged high LR during hold phase allows for rapid convergence.
+ Decay Steps: Final stage, where the LR rapidly decays + with some scaling rate (set by decay rate). + To attain Noam decay, use 0.5, + for Squeezeformer recommended decay, use 1.0. + The fast decay after prolonged high LR during + hold phase allows for rapid convergence. References: - - [Squeezeformer: An Efficient Transformer for Automatic Speech Recognition](https://arxiv.org/abs/2206.00888) + - [Squeezeformer: + An Efficient Transformer for Automatic Speech Recognition] + (https://arxiv.org/abs/2206.00888) Args: optimizer: Pytorch compatible Optimizer object. warmup_steps: Number of training steps in warmup stage warmup_ratio: Ratio of warmup steps to total steps - hold_steps: Number of training steps to hold the learning rate after warm up + hold_steps: Number of training steps to + hold the learning rate after warm up hold_ratio: Ratio of hold steps to total steps max_steps: Total number of steps while training or `None` for infinite training - decay_rate: Float value describing the polynomial decay after the hold period. Default value - of 0.5 corresponds to Noam decay. + decay_rate: Float value describing the polynomial decay + after the hold period. Default value + of 0.5 corresponds to Noam decay. min_lr: Minimum learning rate. """ self.decay_rate = decay_rate - super().__init__(optimizer=optimizer, max_steps=max_steps, last_epoch=last_epoch, min_lr=min_lr, **kwargs) + super().__init__(optimizer=optimizer, max_steps=max_steps, + last_epoch=last_epoch, min_lr=min_lr, **kwargs) def _get_lr(self, step): if self.warmup_steps is None or self.warmup_steps == 0: - raise ValueError("Noam scheduler cannot be used without warmup steps") + raise ValueError( + "Noam scheduler cannot be used without warmup steps") if self.hold_steps > 0: hold_steps = self.hold_steps - self.warmup_steps @@ -610,4 +665,4 @@ def _get_lr(self, step): return new_lrs def set_step(self, step: int): - self.last_epoch = step + self.last_epoch = step \ No newline at end of file From 6a36c237dc2b9a114f74be44fb667f65a2ba6714 Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 17:15:11 +0800 Subject: [PATCH 05/29] fix formatting issues --- examples/librispeech/squeezeformer/README.md | 43 +++++++++++--------- wenet/squeezeformer/conv2d.py | 1 - wenet/squeezeformer/encoder_layer.py | 3 +- wenet/squeezeformer/subsampling.py | 2 +- wenet/squeezeformer/utils.py | 1 - 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index c73ae1c27..f8af10533 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -1,37 +1,40 @@ # Develop Record + ``` squeezeformer -├── attention.py # reltive multi-head attention module -├── conv2d.py # self defined conv2d valid padding module -├── convolution.py # convolution module in squeezeformer block -├── encoder_layer.py # squeezeformer encoder layer -├── encoder.py # squeezeformer encoder class -├── positionwise_feed_forward.py # feed forward layer -├── subsampling.py # sub-sampling layer, time reduction layer -└── utils.py # residual connection module +├── attention.py # reltive multi-head attention module +├── conv2d.py # self defined conv2d valid padding module +├── convolution.py # convolution module in squeezeformer block +├── encoder_layer.py # squeezeformer encoder layer +├── encoder.py # squeezeformer encoder class +├── positionwise_feed_forward.py # feed forward layer +├── subsampling.py # sub-sampling layer, time reduction 
layer +└── utils.py # residual connection module ``` -* Implementation Details - * Squeezeformer Encoder - * [x] add pre layer norm before squeezeformer block - * [x] derive time reduction layer from tensorflow version - * [x] enable adaptive scale operation - * [x] enable init weights for deep model training - * [x] enable training config and results - * [x] enable dynamic chunk and JIT export - * Training - * [x] enable NoamHoldAnnealing schedular +* Implementation Details + * Squeezeformer Encoder + * [x] add pre layer norm before squeezeformer block + * [x] derive time reduction layer from tensorflow version + * [x] enable adaptive scale operation + * [x] enable init weights for deep model training + * [x] enable training config and results + * [x] enable dynamic chunk and JIT export + * Training + * [x] enable NoamHoldAnnealing schedular # Performance Record ### Conformer -* encoder flops(30s): 2,797,274,624, params: 34,761,608 +* encoder flops(30s): 2,797,274,624, params: 34,761,608 ### Squeezeformer Result (SM12, FFN:1024) + * encoder flops(30s): 21,158,877,440, params: 22,219,912 * Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_squeezeformer.yaml, kernel size 31, lr 0.001, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* Training info: train_squeezeformer.yaml, kernel size 31, lr 0.001, batch size + 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 * Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 | decoding mode | dev clean | dev other | test clean | test other | diff --git a/wenet/squeezeformer/conv2d.py b/wenet/squeezeformer/conv2d.py index 0cc545b33..06ee0bdcf 100644 --- a/wenet/squeezeformer/conv2d.py +++ b/wenet/squeezeformer/conv2d.py @@ -1,4 +1,3 @@ -import torch import torch.nn.functional as F from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py index ac5d38630..621d774e8 100644 --- a/wenet/squeezeformer/encoder_layer.py +++ b/wenet/squeezeformer/encoder_layer.py @@ -1,7 +1,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F -from typing import Optional, Union, Tuple, List, Dict +from typing import Optional, Tuple class SqueezeformerEncoderLayer(nn.Module): diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index 69f683615..6b76cb19c 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F from wenet.transformer.subsampling import BaseSubsampling -from typing import Tuple, Union, Optional, Dict +from typing import Tuple from wenet.squeezeformer.conv2d import Conv2dValid diff --git a/wenet/squeezeformer/utils.py b/wenet/squeezeformer/utils.py index eee6c1381..8b469e14e 100644 --- a/wenet/squeezeformer/utils.py +++ b/wenet/squeezeformer/utils.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import termcolor class ResidualModule(nn.Module): From 139fc368b68e07a08526f8f749cd148f0a4d282e Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 17:26:43 +0800 Subject: [PATCH 06/29] fix formatting issues --- examples/librispeech/squeezeformer/README.md | 2 +- wenet/squeezeformer/attention.py | 6 +- wenet/squeezeformer/conv2d.py | 14 +++-- wenet/squeezeformer/convolution.py | 2 +- wenet/squeezeformer/encoder.py | 61 +++++++++++--------- wenet/squeezeformer/subsampling.py | 12 ++-- wenet/squeezeformer/utils.py | 8 ++- 
wenet/utils/scheduler.py | 2 +- 8 files changed, 63 insertions(+), 44 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index f8af10533..ca911d740 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -6,7 +6,7 @@ squeezeformer ├── conv2d.py # self defined conv2d valid padding module ├── convolution.py # convolution module in squeezeformer block ├── encoder_layer.py # squeezeformer encoder layer -├── encoder.py # squeezeformer encoder class +├── encoder.py # squeezeformer encoder class ├── positionwise_feed_forward.py # feed forward layer ├── subsampling.py # sub-sampling layer, time reduction layer └── utils.py # residual connection module diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py index e7e06ee13..21a37e689 100644 --- a/wenet/squeezeformer/attention.py +++ b/wenet/squeezeformer/attention.py @@ -14,7 +14,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): dropout_rate (float): Dropout rate. """ - def __init__(self, n_head, n_feat, dropout_rate, do_rel_shift=False, adaptive_scale=False, init_weights=False): + def __init__(self, n_head, n_feat, dropout_rate, + do_rel_shift=False, adaptive_scale=False, init_weights=False): """Construct an RelPositionMultiHeadedAttention object.""" super().__init__(n_head, n_feat, dropout_rate) # linear transformation for positional encoding @@ -100,7 +101,8 @@ def forward_attention( # For last chunk, time2 might be larger than scores.size(-1) mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + # (batch, head, time1, time2) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # NOTE(xcsong): When will `if mask.size(2) > 0` be False? # 1. onnx(16/-1, -1/-1, 16/0) # 2. 
jit (16/-1, -1/-1, 16/0, 16/4) diff --git a/wenet/squeezeformer/conv2d.py b/wenet/squeezeformer/conv2d.py index 06ee0bdcf..2acb9e96a 100644 --- a/wenet/squeezeformer/conv2d.py +++ b/wenet/squeezeformer/conv2d.py @@ -28,17 +28,21 @@ def __init__( padding_ = padding if isinstance(padding, str) else _pair(padding) dilation_ = _pair(dilation) super(Conv2dValid, self).__init__( - in_channels, out_channels, kernel_size_, stride_, padding_, dilation_, - False, _pair(0), groups, bias, padding_mode, **factory_kwargs) + in_channels, out_channels, kernel_size_, + stride_, padding_, dilation_, False, _pair(0), + groups, bias, padding_mode, **factory_kwargs) self.valid_trigx = valid_trigx self.valid_trigy = valid_trigy - def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): + def _conv_forward( + self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): validx, validy = 0, 0 if self.valid_trigx: - validx = (input.size(-2) * (self.stride[-2] - 1) - 1 + self.kernel_size[-2]) // 2 + validx = (input.size(-2) * (self.stride[-2] - 1) - 1 + + self.kernel_size[-2]) // 2 if self.valid_trigy: - validy = (input.size(-1) * (self.stride[-1] - 1) - 1 + self.kernel_size[-1]) // 2 + validy = (input.size(-1) * (self.stride[-1] - 1) - 1 + + self.kernel_size[-1]) // 2 return F.conv2d(input, weight, bias, self.stride, (validx, validy), self.dilation, self.groups) diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py index 055540954..688fe220f 100644 --- a/wenet/squeezeformer/convolution.py +++ b/wenet/squeezeformer/convolution.py @@ -169,4 +169,4 @@ def forward( if mask_pad.size(2) > 0: # time > 0 x.masked_fill_(~mask_pad, 0.0) - return x.transpose(1, 2), new_cache \ No newline at end of file + return x.transpose(1, 2), new_cache diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index 1c23a92b9..4de48dc77 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -2,12 +2,14 @@ import torch.nn as nn from typing import Tuple from wenet.squeezeformer.utils import ResidualModule -from wenet.squeezeformer.subsampling import DepthwiseConv2dSubsampling4, TimeReductionLayer +from wenet.squeezeformer.subsampling \ + import DepthwiseConv2dSubsampling4, TimeReductionLayer from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer from wenet.transformer.embedding import RelPositionalEncoding from wenet.transformer.attention import MultiHeadedAttention from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention -from wenet.squeezeformer.positionwise_feed_forward import PositionwiseFeedForward +from wenet.squeezeformer.positionwise_feed_forward \ + import PositionwiseFeedForward from wenet.squeezeformer.convolution import ConvolutionModule from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask from wenet.utils.common import get_activation @@ -45,26 +47,27 @@ def __init__( ): """Construct SqueezeformerEncoder - Args: - input_size to use_dynamic_chunk, see in Transformer BaseEncoder. - encoder_dim (int): The hidden dimension of encoder layer. - output_size (int): The output dimension of final projection layer. - attention_heads (int): Num of attention head in attention module. - num_blocks (int): Num of encoder layers. - reduce_idx (int): reduce layer index, from 40ms to 80ms per frame. - recover_idx (int): recover layer index, from 80ms to 40ms per frame. - feed_forward_expansion_factor (int): Enlarge coefficient of FFN layer. 
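
A note on the `validx`/`validy` arithmetic in `Conv2dValid._conv_forward` above: plugging `pad = (n * (stride - 1) - 1 + kernel) // 2` into the standard convolution length formula keeps the axis length at `n` (up to integer rounding), which is presumably the point of the "valid" trigger flags. A standalone check:

```python
import torch
import torch.nn.functional as F

# out = (n + 2 * pad - kernel) // stride + 1; with the padding rule above
# this comes back to n on the triggered axis.
n, kernel, stride = 25, 5, 2
pad = (n * (stride - 1) - 1 + kernel) // 2
x = torch.randn(1, 1, n)
w = torch.randn(1, 1, kernel)
print(F.conv1d(x, w, stride=stride, padding=pad).shape)  # torch.Size([1, 1, 25])
```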
- input_dropout_rate (float): Dropout rate of input projection layer. - pos_enc_layer_type (str): Self attention type. - do_rel_shift (bool): Whether to do relative shift operation on rel-attention module. - cnn_module_kernel (int): Kernel size of CNN module. - activation_type (str): Encoder activation function type. - use_cnn_module (bool): Whether to use convolution module. - cnn_module_kernel (int): Kernel size of convolution module. - adaptive_scale (bool): Whether to use adaptive scale. - init_weights (bool): Whether to initialize weights. - causal (bool): whether to use causal convolution or not. - """ + Args: + input_size to use_dynamic_chunk, see in Transformer BaseEncoder. + encoder_dim (int): The hidden dimension of encoder layer. + output_size (int): The output dimension of final projection layer. + attention_heads (int): Num of attention head in attention module. + num_blocks (int): Num of encoder layers. + reduce_idx (int): reduce layer index, from 40ms to 80ms per frame. + recover_idx (int): recover layer index, from 80ms to 40ms per frame. + feed_forward_expansion_factor (int): Enlarge coefficient of FFN. + input_dropout_rate (float): Dropout rate of input projection layer. + pos_enc_layer_type (str): Self attention type. + do_rel_shift (bool): Whether to do relative shift + operation on rel-attention module. + cnn_module_kernel (int): Kernel size of CNN module. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + adaptive_scale (bool): Whether to use adaptive scale. + init_weights (bool): Whether to initialize weights. + causal (bool): whether to use causal convolution or not. + """ super(SqueezeformerEncoder, self).__init__() self.global_cmvn = global_cmvn self.reduce_idx = reduce_idx @@ -107,14 +110,17 @@ def __init__( # convolution module definition convolution_layer = ConvolutionModule - convolution_layer_args = (encoder_dim, cnn_module_kernel, activation, - cnn_norm_type, causal, adaptive_scale, init_weights) + convolution_layer_args = ( + encoder_dim, cnn_module_kernel, activation, + cnn_norm_type, causal, adaptive_scale, init_weights) self.embed = DepthwiseConv2dSubsampling4( - 1, encoder_dim, RelPositionalEncoding(encoder_dim, dropout_rate=0.1) + 1, encoder_dim, + RelPositionalEncoding(encoder_dim, dropout_rate=0.1) ) self.input_proj = nn.Sequential( - nn.Linear(encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim), + nn.Linear( + encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim), nn.Dropout(p=input_dropout_rate), ) self.preln = nn.LayerNorm(encoder_dim) @@ -218,4 +224,5 @@ def forward( if self.final_proj is not None: xs = self.final_proj(xs) - return xs, masks \ No newline at end of file + return xs, masks + \ No newline at end of file diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index 6b76cb19c..63bbf7c27 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -21,9 +21,11 @@ def __init__( super(DepthwiseConv2dSubsampling4, self).__init__() self.idim = idim self.odim = odim - self.pw_conv = nn.Conv2d(in_channels=idim, out_channels=odim, kernel_size=3, stride=2) + self.pw_conv = nn.Conv2d( + in_channels=idim, out_channels=odim, kernel_size=3, stride=2) self.act1 = nn.ReLU() - self.dw_conv = nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2) + self.dw_conv = nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, 
stride=2) self.act2 = nn.ReLU() self.pos_enc = pos_enc_class self.subsampling_rate = 4 @@ -84,10 +86,12 @@ def init_weights(self): torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: xs = xs.unsqueeze(2) padding1 = self.kernel_size - self.stride - xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), mode='constant', value=0.) + xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), + mode='constant', value=0.) xs = self.dw_conv(xs.transpose(1, 2)) xs = xs.permute(0, 3, 1, 2).contiguous() xs = self.pw_conv(xs).permute(0, 2, 3, 1).squeeze(1).contiguous() diff --git a/wenet/squeezeformer/utils.py b/wenet/squeezeformer/utils.py index 8b469e14e..529a62d12 100644 --- a/wenet/squeezeformer/utils.py +++ b/wenet/squeezeformer/utils.py @@ -15,10 +15,12 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor, pos_emb: torch.Tensor, - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones( + (0, 0, 0), dtype=torch.bool), att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), ): - x, mask, new_att_cache, new_cnn_cache = self.layer(x, mask, pos_emb, mask_pad, att_cache, cnn_cache) + x, mask, new_att_cache, new_cnn_cache = self.layer( + x, mask, pos_emb, mask_pad, att_cache, cnn_cache) x = x * self.coef + x - return x, mask, new_att_cache, new_cnn_cache \ No newline at end of file + return x, mask, new_att_cache, new_cnn_cache diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index e68255cbb..d7cfa72c3 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -665,4 +665,4 @@ def _get_lr(self, step): return new_lrs def set_step(self, step: int): - self.last_epoch = step \ No newline at end of file + self.last_epoch = step From 084c8436a181831ca8a0ce09dffe0837f0cf2331 Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 17:35:27 +0800 Subject: [PATCH 07/29] fix formatting issues --- wenet/squeezeformer/conv2d.py | 2 +- wenet/squeezeformer/encoder.py | 1 - wenet/utils/scheduler.py | 32 ++++++++++++++++---------------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/wenet/squeezeformer/conv2d.py b/wenet/squeezeformer/conv2d.py index 2acb9e96a..37b386002 100644 --- a/wenet/squeezeformer/conv2d.py +++ b/wenet/squeezeformer/conv2d.py @@ -47,4 +47,4 @@ def _conv_forward( (validx, validy), self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.weight, self.bias) \ No newline at end of file + return self._conv_forward(input, self.weight, self.bias) diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index 4de48dc77..bae0c860a 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -225,4 +225,3 @@ def forward( if self.final_proj is not None: xs = self.final_proj(xs) return xs, masks - \ No newline at end of file diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index d7cfa72c3..151b062e1 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -87,9 +87,8 @@ class WarmupPolicy(_LRScheduler): def __init__(self, optimizer, *, warmup_steps=None, warmup_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1): - assert not ( - warmup_steps is not None and warmup_ratio is not None 
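
The two kernel-3, stride-2 convolutions in `DepthwiseConv2dSubsampling4` above shrink the time and feature axes by roughly 4x each, which is exactly where the `encoder_dim * (((input_size - 1) // 2 - 1) // 2)` width of the encoder's input projection (seen a little earlier) comes from. A quick shape check with plain convolutions and illustrative sizes:

```python
import torch
import torch.nn as nn

# Each k=3, s=2 conv maps length n to (n - 1) // 2; two of them give
# ((n - 1) // 2 - 1) // 2 on both the time and feature axes.
x = torch.randn(1, 1, 100, 80)   # (batch, channel, time, features)
sub4 = nn.Sequential(nn.Conv2d(1, 8, 3, 2), nn.ReLU(),
                     nn.Conv2d(8, 8, 3, 2), nn.ReLU())
print(sub4(x).shape)                                        # [1, 8, 24, 19]
print(((100 - 1) // 2 - 1) // 2, ((80 - 1) // 2 - 1) // 2)  # 24 19
```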
- ), "Either use particular number of step or ratio" + assert not (warmup_steps is not None and warmup_ratio is not None),\ + "Either use particular number of step or ratio" assert warmup_ratio is None or max_steps is not None, \ "If there is a ratio, there should be a total steps" @@ -147,9 +146,9 @@ def __init__( self, optimizer, *, constant_steps=None, constant_ratio=None, max_steps=None, min_lr=0.0, last_epoch=-1 ): - assert not ( - constant_steps is not None and constant_ratio is not None - ), "Either use particular number of step or ratio" + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use particular number of step or ratio" assert constant_ratio is None or max_steps is not None, \ "If there is a ratio, there should be a total steps" @@ -298,12 +297,12 @@ def __init__( min_lr=0.0, last_epoch=-1, ): - assert not ( - warmup_steps is not None and warmup_ratio is not None - ), "Either use particular number of step or ratio" - assert not ( - constant_steps is not None and constant_ratio is not None - ), "Either use constant_steps or constant_ratio" + assert not (warmup_steps is not None + and warmup_ratio is not None), \ + "Either use particular number of step or ratio" + assert not (constant_steps is not None + and constant_ratio is not None), \ + "Either use constant_steps or constant_ratio" assert warmup_ratio is None or max_steps is not None, \ "If there is a ratio, there should be a total steps" @@ -508,7 +507,8 @@ def _get_constant_lr(self, step): return self._get_linear_warmup_with_cosine_annealing_lr(step) def _get_linear_warmup_with_cosine_annealing_lr(self, step): - # Cosine Schedule for Megatron LM, slightly different warmup schedule + constant LR at the end. + # Cosine Schedule for Megatron LM, + # slightly different warmup schedule + constant LR at the end. new_lrs = [ _linear_warmup_with_cosine_annealing( max_lr=self.base_lrs[0], @@ -528,9 +528,9 @@ def __init__( max_steps=None, min_lr=0.0, last_epoch=-1 ): self._normalize = d_model ** (-0.5) - assert not ( - warmup_steps is not None and warmup_ratio is not None - ), "Either use particular number of step or ratio" + assert not (warmup_steps is not None + and warmup_ratio is not None), \ + "Either use particular number of step or ratio" assert warmup_ratio is None or max_steps is not None, \ "If there is a ratio, there should be a total steps" From 1f4a8b3b6337482013f564440e4bfe23ead64f40 Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 20:21:07 +0800 Subject: [PATCH 08/29] [update] change residual connection & add copyrights --- wenet/squeezeformer/attention.py | 19 +++++ wenet/squeezeformer/conv2d.py | 16 +++++ wenet/squeezeformer/convolution.py | 1 + wenet/squeezeformer/encoder.py | 71 +++++++++---------- wenet/squeezeformer/encoder_layer.py | 16 +++++ .../positionwise_feed_forward.py | 1 + wenet/squeezeformer/subsampling.py | 18 +++++ wenet/squeezeformer/utils.py | 26 ------- wenet/utils/scheduler.py | 2 + 9 files changed, 105 insertions(+), 65 deletions(-) delete mode 100644 wenet/squeezeformer/utils.py diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py index 21a37e689..737216665 100644 --- a/wenet/squeezeformer/attention.py +++ b/wenet/squeezeformer/attention.py @@ -1,3 +1,22 @@ +# Copyright (c) 2019 Shigeki Karita +# 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) +# 2022 Ximalaya Inc. 
(Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multi-Head Attention layer definition.""" + import math import torch import torch.nn as nn diff --git a/wenet/squeezeformer/conv2d.py b/wenet/squeezeformer/conv2d.py index 37b386002..c23026339 100644 --- a/wenet/squeezeformer/conv2d.py +++ b/wenet/squeezeformer/conv2d.py @@ -1,3 +1,19 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conv2d Module with Valid Padding""" + import torch.nn.functional as F from torch.nn.modules.conv import _ConvNd, _size_2_t, Union, _pair, Tensor, Optional diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py index 688fe220f..b2086547f 100644 --- a/wenet/squeezeformer/convolution.py +++ b/wenet/squeezeformer/convolution.py @@ -1,4 +1,5 @@ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) +# 2022 Ximalaya Inc. (authors: Yuguang Yang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index bae0c860a..b51230087 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -1,3 +1,19 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) +# Squeezeformer(https://github.com/upskyy/Squeezeformer) + import torch import torch.nn as nn from typing import Tuple @@ -124,44 +140,16 @@ def __init__( nn.Dropout(p=input_dropout_rate), ) self.preln = nn.LayerNorm(encoder_dim) - self.encoders = torch.nn.ModuleList() - for layer_id in range(num_blocks): - if layer_id < reduce_idx: - self.encoders.append( - SqueezeformerEncoderLayer( - encoder_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - convolution_layer(*convolution_layer_args), - positionwise_layer(*positionwise_layer_args), - normalize_before, - dropout, - concat_after - )) - elif reduce_idx <= layer_id < recover_idx: - self.encoders.append( - ResidualModule(SqueezeformerEncoderLayer( - encoder_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - convolution_layer(*convolution_layer_args), - positionwise_layer(*positionwise_layer_args), - normalize_before, - dropout, - concat_after - ))) - else: - self.encoders.append( - SqueezeformerEncoderLayer( - encoder_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - convolution_layer(*convolution_layer_args), - positionwise_layer(*positionwise_layer_args), - normalize_before, - dropout, - concat_after - )) + self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after) for _ in range(num_blocks) + ]) self.time_reduction_layer = TimeReductionLayer(encoder_dim=encoder_dim) self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) self.final_proj = None @@ -220,7 +208,12 @@ def forward( pos_emb = recover_pos_emb mask_pad = recover_mask_pad - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + if self.reduce_idx <= idx < self.recover_idx: + residual = xs + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + xs += residual + else: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) if self.final_proj is not None: xs = self.final_proj(xs) diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py index 621d774e8..5d848ec63 100644 --- a/wenet/squeezeformer/encoder_layer.py +++ b/wenet/squeezeformer/encoder_layer.py @@ -1,3 +1,19 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
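
The encoder rewrite above replaces the deleted `ResidualModule` wrapper with an explicit skip connection in `forward` for the layers between `reduce_idx` and `recover_idx`, and the semantics change with it: the old wrapper computed `x * coef + x` on the layer *output* (never touching the block input), whereas the new code adds the block *input* back. A toy contrast:

```python
import torch

def layer(x):                  # stand-in for a Squeezeformer layer
    return x + 1.0

x = torch.zeros(3)

out = layer(x)
old = out * 1.0 + out          # old ResidualModule (coef=1): doubles output
new = layer(x) + x             # new encoder.forward: true input residual
print(old[0].item(), new[0].item())  # 2.0 1.0
```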
+ +"""SqueezeformerEncoderLayer definition.""" + import torch import torch.nn as nn from typing import Optional, Tuple diff --git a/wenet/squeezeformer/positionwise_feed_forward.py b/wenet/squeezeformer/positionwise_feed_forward.py index 3e2ae45ae..a9ce510a5 100644 --- a/wenet/squeezeformer/positionwise_feed_forward.py +++ b/wenet/squeezeformer/positionwise_feed_forward.py @@ -1,5 +1,6 @@ # Copyright (c) 2019 Shigeki Karita # 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Ximalaya Inc (Yuguang Yang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index 63bbf7c27..6ce394376 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -1,3 +1,21 @@ +# Copyright (c) 2022 Ximalaya Inc. (authors: Yuguang Yang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) +# Squeezeformer(https://github.com/upskyy/Squeezeformer) + +"""DepthwiseConv2dSubsampling4 and TimeReductionLayer definition.""" + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/wenet/squeezeformer/utils.py b/wenet/squeezeformer/utils.py deleted file mode 100644 index 529a62d12..000000000 --- a/wenet/squeezeformer/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch -import torch.nn as nn - - -class ResidualModule(nn.Module): - """ - Residual Connection Module for Squeezeformer Encoder Layer - """ - def __init__(self, layer: nn.Module, coef: float = 1.0): - super(ResidualModule, self).__init__() - self.layer = layer - self.coef = coef - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor, - pos_emb: torch.Tensor, - mask_pad: torch.Tensor = torch.ones( - (0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ): - x, mask, new_att_cache, new_cnn_cache = self.layer( - x, mask, pos_emb, mask_pad, att_cache, cnn_cache) - x = x * self.coef + x - return x, mask, new_att_cache, new_cnn_cache diff --git a/wenet/utils/scheduler.py b/wenet/utils/scheduler.py index 151b062e1..383096cbb 100644 --- a/wenet/utils/scheduler.py +++ b/wenet/utils/scheduler.py @@ -1,4 +1,5 @@ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) +# 2022 Ximalaya Inc (Yuguang Yang) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# Modified from ESPnet(https://github.com/espnet/espnet) +# NeMo(https://github.com/NVIDIA/NeMo) from typing import Union From 0558fbbafeefa6df2cd34c77155282a290f3c55a Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 20:24:35 +0800 Subject: [PATCH 09/29] fix formatting issues --- wenet/squeezeformer/encoder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index b51230087..783573940 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -17,7 +17,6 @@ import torch import torch.nn as nn from typing import Tuple -from wenet.squeezeformer.utils import ResidualModule from wenet.squeezeformer.subsampling \ import DepthwiseConv2dSubsampling4, TimeReductionLayer from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer From 0264049920c5addb3127e525b72594dacfb7005e Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 21:02:40 +0800 Subject: [PATCH 10/29] [update] enlarge adaptive scale dimensions --- wenet/squeezeformer/attention.py | 10 +++++----- wenet/squeezeformer/convolution.py | 6 +++--- wenet/squeezeformer/positionwise_feed_forward.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py index 737216665..390706b61 100644 --- a/wenet/squeezeformer/attention.py +++ b/wenet/squeezeformer/attention.py @@ -48,8 +48,8 @@ def __init__(self, n_head, n_feat, dropout_rate, torch.nn.init.xavier_uniform_(self.pos_bias_v) self.adaptive_scale = adaptive_scale if self.adaptive_scale: - self.scale = nn.Parameter(torch.tensor(1.), requires_grad=True) - self.bias = nn.Parameter(torch.tensor(0.), requires_grad=True) + self.ada_scale = nn.Parameter(torch.ones(n_feat), requires_grad=True).reshape([1, 1, -1]) + self.ada_bias = nn.Parameter(torch.zeros(n_feat), requires_grad=True).reshape([1, 1, -1]) if init_weights: self.init_weights() @@ -161,9 +161,9 @@ def forward(self, query: torch.Tensor, and `head * d_k == size` """ if self.adaptive_scale: - query = self.scale * query + self.bias - key = self.scale * key + self.bias - value = self.scale * value + self.bias + query = self.ada_scale * query + self.ada_bias + key = self.ada_scale * key + self.ada_bias + value = self.ada_scale * value + self.ada_bias q, k, v = self.forward_qkv(query, key, value) q = q.transpose(1, 2) # (batch, time1, head, d_k) diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py index b2086547f..5d152cb3d 100644 --- a/wenet/squeezeformer/convolution.py +++ b/wenet/squeezeformer/convolution.py @@ -48,8 +48,8 @@ def __init__(self, self.kernel_size = kernel_size self.adaptive_scale = adaptive_scale if self.adaptive_scale: - self.scale = torch.nn.Parameter(torch.tensor(1.), requires_grad=True) - self.bias = torch.nn.Parameter(torch.tensor(0.), requires_grad=True) + self.ada_scale = torch.nn.Parameter(torch.ones(channels), requires_grad=True).reshape([1, 1, -1]) + self.ada_bias = torch.nn.Parameter(torch.zeros(channels), requires_grad=True).reshape([1, 1, -1]) self.pointwise_conv1 = nn.Conv1d( channels, @@ -132,7 +132,7 @@ def forward( torch.Tensor: Output tensor (#batch, time, channels). 
""" if self.adaptive_scale: - x = self.scale * x + self.bias + x = self.ada_scale * x + self.ada_bias # exchange the temporal dimension and the feature dimension x = x.transpose(1, 2) # (#batch, channels, time) # mask batch padding diff --git a/wenet/squeezeformer/positionwise_feed_forward.py b/wenet/squeezeformer/positionwise_feed_forward.py index a9ce510a5..39c99db03 100644 --- a/wenet/squeezeformer/positionwise_feed_forward.py +++ b/wenet/squeezeformer/positionwise_feed_forward.py @@ -50,8 +50,8 @@ def __init__(self, self.w_2 = torch.nn.Linear(hidden_units, idim) self.adaptive_scale = adaptive_scale if self.adaptive_scale: - self.scale = torch.nn.Parameter(torch.tensor(1.), requires_grad=True) - self.bias = torch.nn.Parameter(torch.tensor(0.), requires_grad=True) + self.ada_scale = torch.nn.Parameter(torch.ones(idim), requires_grad=True).reshape([1, 1, -1]) + self.ada_bias = torch.nn.Parameter(torch.zeros(idim), requires_grad=True).reshape([1, 1, -1]) if init_weights: self.init_weights() @@ -72,5 +72,5 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor: output tensor, (B, L, D) """ if self.adaptive_scale: - xs = self.scale * xs + self.bias + xs = self.ada_scale * xs + self.ada_bias return self.w_2(self.dropout(self.activation(self.w_1(xs)))) From be2f56ed589c22d1f184c4e883c9140ef1f376a9 Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 15 Sep 2022 21:05:38 +0800 Subject: [PATCH 11/29] fix formatting issues --- wenet/squeezeformer/attention.py | 6 ++++-- wenet/squeezeformer/convolution.py | 6 ++++-- wenet/squeezeformer/positionwise_feed_forward.py | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py index 390706b61..634e38d0e 100644 --- a/wenet/squeezeformer/attention.py +++ b/wenet/squeezeformer/attention.py @@ -48,8 +48,10 @@ def __init__(self, n_head, n_feat, dropout_rate, torch.nn.init.xavier_uniform_(self.pos_bias_v) self.adaptive_scale = adaptive_scale if self.adaptive_scale: - self.ada_scale = nn.Parameter(torch.ones(n_feat), requires_grad=True).reshape([1, 1, -1]) - self.ada_bias = nn.Parameter(torch.zeros(n_feat), requires_grad=True).reshape([1, 1, -1]) + self.ada_scale = nn.Parameter( + torch.ones(n_feat), requires_grad=True).reshape([1, 1, -1]) + self.ada_bias = nn.Parameter( + torch.zeros(n_feat), requires_grad=True).reshape([1, 1, -1]) if init_weights: self.init_weights() diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py index 5d152cb3d..a40ae7bdf 100644 --- a/wenet/squeezeformer/convolution.py +++ b/wenet/squeezeformer/convolution.py @@ -48,8 +48,10 @@ def __init__(self, self.kernel_size = kernel_size self.adaptive_scale = adaptive_scale if self.adaptive_scale: - self.ada_scale = torch.nn.Parameter(torch.ones(channels), requires_grad=True).reshape([1, 1, -1]) - self.ada_bias = torch.nn.Parameter(torch.zeros(channels), requires_grad=True).reshape([1, 1, -1]) + self.ada_scale = torch.nn.Parameter( + torch.ones(channels), requires_grad=True).reshape([1, 1, -1]) + self.ada_bias = torch.nn.Parameter( + torch.zeros(channels), requires_grad=True).reshape([1, 1, -1]) self.pointwise_conv1 = nn.Conv1d( channels, diff --git a/wenet/squeezeformer/positionwise_feed_forward.py b/wenet/squeezeformer/positionwise_feed_forward.py index 39c99db03..ddd795e54 100644 --- a/wenet/squeezeformer/positionwise_feed_forward.py +++ b/wenet/squeezeformer/positionwise_feed_forward.py @@ -50,8 +50,10 @@ def __init__(self, self.w_2 = torch.nn.Linear(hidden_units, idim) 
self.adaptive_scale = adaptive_scale if self.adaptive_scale: - self.ada_scale = torch.nn.Parameter(torch.ones(idim), requires_grad=True).reshape([1, 1, -1]) - self.ada_bias = torch.nn.Parameter(torch.zeros(idim), requires_grad=True).reshape([1, 1, -1]) + self.ada_scale = torch.nn.Parameter( + torch.ones(idim), requires_grad=True).reshape([1, 1, -1]) + self.ada_bias = torch.nn.Parameter( + torch.zeros(idim), requires_grad=True).reshape([1, 1, -1]) if init_weights: self.init_weights() From ba6825c458f81be9295f93018c91113fe9e50420 Mon Sep 17 00:00:00 2001 From: yygle Date: Fri, 16 Sep 2022 11:33:36 +0800 Subject: [PATCH 12/29] fix adaptive scale bugs --- wenet/squeezeformer/attention.py | 4 ++-- wenet/squeezeformer/convolution.py | 4 ++-- wenet/squeezeformer/positionwise_feed_forward.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/wenet/squeezeformer/attention.py b/wenet/squeezeformer/attention.py index 634e38d0e..141113f1b 100644 --- a/wenet/squeezeformer/attention.py +++ b/wenet/squeezeformer/attention.py @@ -49,9 +49,9 @@ def __init__(self, n_head, n_feat, dropout_rate, self.adaptive_scale = adaptive_scale if self.adaptive_scale: self.ada_scale = nn.Parameter( - torch.ones(n_feat), requires_grad=True).reshape([1, 1, -1]) + torch.ones([1, 1, n_feat]), requires_grad=True) self.ada_bias = nn.Parameter( - torch.zeros(n_feat), requires_grad=True).reshape([1, 1, -1]) + torch.zeros([1, 1, n_feat]), requires_grad=True) if init_weights: self.init_weights() diff --git a/wenet/squeezeformer/convolution.py b/wenet/squeezeformer/convolution.py index a40ae7bdf..e3a4e7a75 100644 --- a/wenet/squeezeformer/convolution.py +++ b/wenet/squeezeformer/convolution.py @@ -49,9 +49,9 @@ def __init__(self, self.adaptive_scale = adaptive_scale if self.adaptive_scale: self.ada_scale = torch.nn.Parameter( - torch.ones(channels), requires_grad=True).reshape([1, 1, -1]) + torch.ones([1, 1, channels]), requires_grad=True) self.ada_bias = torch.nn.Parameter( - torch.zeros(channels), requires_grad=True).reshape([1, 1, -1]) + torch.zeros([1, 1, channels]), requires_grad=True) self.pointwise_conv1 = nn.Conv1d( channels, diff --git a/wenet/squeezeformer/positionwise_feed_forward.py b/wenet/squeezeformer/positionwise_feed_forward.py index ddd795e54..e02d2f0dc 100644 --- a/wenet/squeezeformer/positionwise_feed_forward.py +++ b/wenet/squeezeformer/positionwise_feed_forward.py @@ -51,9 +51,9 @@ def __init__(self, self.adaptive_scale = adaptive_scale if self.adaptive_scale: self.ada_scale = torch.nn.Parameter( - torch.ones(idim), requires_grad=True).reshape([1, 1, -1]) + torch.ones([1, 1, idim]), requires_grad=True) self.ada_bias = torch.nn.Parameter( - torch.zeros(idim), requires_grad=True).reshape([1, 1, -1]) + torch.zeros([1, 1, idim]), requires_grad=True) if init_weights: self.init_weights() From 89f133e9a43b297d555165d526594f0875bdf3c5 Mon Sep 17 00:00:00 2001 From: yygle Date: Tue, 20 Sep 2022 19:58:31 +0800 Subject: [PATCH 13/29] [update] encoder.py(fix init weights bugs) and README.md --- examples/librispeech/squeezeformer/README.md | 37 ++++++++++++++++---- wenet/squeezeformer/encoder.py | 2 +- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index ca911d740..d2c5b5f47 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -26,15 +26,25 @@ squeezeformer # Performance Record ### Conformer +* Encoder FLOPs(30s): 34,085,088,512, params: 
34,761,608 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: train_conformer.yaml, kernel size 31, lr 0.004, +* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* Decoding info: ctc_weight 0.5, average_num 30 -* encoder flops(30s): 2,797,274,624, params: 34,761,608 +| decoding mode | test clean | test other | +|----------------------------------|------------|------------| +| ctc greedy search | 3.51 | 9.57 | +| ctc prefix beam search | 3.51 | 9.56 | +| attention decoder | 3.05 | 8.36 | +| attention rescoring | 3.18 | 8.72 | ### Squeezeformer Result (SM12, FFN:1024) - -* encoder flops(30s): 21,158,877,440, params: 22,219,912 +* Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912 * Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_squeezeformer.yaml, kernel size 31, lr 0.001, batch size - 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* Training info: train_squeezeformer.yaml, kernel size 31, +* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* AdamW, lr=1e-3, NoamHold, warmup=0.2, hold=0.3, lr_decay=1.0 * Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 | decoding mode | dev clean | dev other | test clean | test other | @@ -42,4 +52,19 @@ squeezeformer | ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | | ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | | attention decoder | 8.74 | 3.59 | 3.75 | 8.70 | -| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | \ No newline at end of file +| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | + +### Squeezeformer Result (SM12, FFN:2048) +* Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: train_squeezeformer.yaml, kernel size 31, +* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* AdamW, lr=1e-3, NoamHold, warmup=0.2, hold=0.3, lr_decay=1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.34 | 9.01 | 3.47 | 8.85 | +| ctc prefix beam search | 3.33 | 9.02 | 3.46 | 8.81 | +| attention decoder | 8.62 | 3.64 | 3.91 | 8.33 | +| attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 | diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index 783573940..b1aa645f3 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -127,7 +127,7 @@ def __init__( convolution_layer = ConvolutionModule convolution_layer_args = ( encoder_dim, cnn_module_kernel, activation, - cnn_norm_type, causal, adaptive_scale, init_weights) + cnn_norm_type, causal, True, adaptive_scale, init_weights) self.embed = DepthwiseConv2dSubsampling4( 1, encoder_dim, From 78b80776158f7aa57094a10b12429b6f2bc55000 Mon Sep 17 00:00:00 2001 From: yygle Date: Wed, 21 Sep 2022 23:55:05 +0800 Subject: [PATCH 14/29] [update] initialization for input projection --- wenet/squeezeformer/encoder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index b1aa645f3..c7ca47c78 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -138,6 +138,10 @@ def __init__( encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim), nn.Dropout(p=input_dropout_rate), ) + if init_weights: + linear_max = (encoder_dim * input_size / 4) ** -0.5 + 
torch.nn.init.uniform_(self.input_proj.state_dict()['0.weight'], -linear_max, linear_max) + torch.nn.init.uniform_(self.input_proj.state_dict()['0.bias'], -linear_max, linear_max) self.preln = nn.LayerNorm(encoder_dim) self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( encoder_dim, From 76fbcf20c7bb3ff5767d846944e88b4f8202775a Mon Sep 17 00:00:00 2001 From: yygle Date: Wed, 21 Sep 2022 23:58:31 +0800 Subject: [PATCH 15/29] fix formatting issues --- wenet/squeezeformer/encoder.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index c7ca47c78..7e5cbdbe3 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -140,8 +140,10 @@ def __init__( ) if init_weights: linear_max = (encoder_dim * input_size / 4) ** -0.5 - torch.nn.init.uniform_(self.input_proj.state_dict()['0.weight'], -linear_max, linear_max) - torch.nn.init.uniform_(self.input_proj.state_dict()['0.bias'], -linear_max, linear_max) + torch.nn.init.uniform_( + self.input_proj.state_dict()['0.weight'], -linear_max, linear_max) + torch.nn.init.uniform_( + self.input_proj.state_dict()['0.bias'], -linear_max, linear_max) self.preln = nn.LayerNorm(encoder_dim) self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( encoder_dim, From cefa4cd82f6689579357ac3e86d89323265f4917 Mon Sep 17 00:00:00 2001 From: yygle Date: Thu, 22 Sep 2022 09:36:25 +0800 Subject: [PATCH 16/29] fix formatting issues --- examples/librispeech/squeezeformer/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index d2c5b5f47..be3140aa6 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -51,7 +51,7 @@ squeezeformer |----------------------------------|-----------|-----------|------------|------------| | ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | | ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | -| attention decoder | 8.74 | 3.59 | 3.75 | 8.70 | +| attention decoder | 3.59 | 8.74 | 3.75 | 8.70 | | attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | ### Squeezeformer Result (SM12, FFN:2048) @@ -66,5 +66,5 @@ squeezeformer |----------------------------------|-----------|-----------|------------|------------| | ctc greedy search | 3.34 | 9.01 | 3.47 | 8.85 | | ctc prefix beam search | 3.33 | 9.02 | 3.46 | 8.81 | -| attention decoder | 8.62 | 3.64 | 3.91 | 8.33 | +| attention decoder | 3.64 | 8.62 | 3.91 | 8.33 | | attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 | From c2f2a05547de39bf1eb1c68d32b2dba36f0b59ef Mon Sep 17 00:00:00 2001 From: yygle Date: Fri, 23 Sep 2022 17:40:31 +0800 Subject: [PATCH 17/29] [update] time reduction layer with conv1d and conv2d --- .../conf/train_squeezeformer.yaml | 4 +- wenet/bin/train.py | 6 +- wenet/squeezeformer/encoder.py | 28 ++++-- wenet/squeezeformer/subsampling.py | 88 +++++++++++++++++-- 4 files changed, 106 insertions(+), 20 deletions(-) diff --git a/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml b/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml index 202c95f42..15dd2d33b 100644 --- a/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml +++ b/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml @@ -8,6 +8,8 @@ encoder_conf: num_blocks: 12 # the number of encoder blocks reduce_idx: 5 recover_idx: 11 + pos_enc_layer_type: 'rel_pos' + time_reduction_layer_type: 
'conv1d' feed_forward_expansion_factor: 4 input_dropout_rate: 0.1 feed_forward_dropout_rate: 0.1 @@ -36,8 +38,6 @@ model_conf: # dataset related dataset_conf: - even_sample: false - syncbn: false filter_conf: max_length: 2000 min_length: 50 diff --git a/wenet/bin/train.py b/wenet/bin/train.py index 6fabc7b5b..53f14e2af 100644 --- a/wenet/bin/train.py +++ b/wenet/bin/train.py @@ -195,7 +195,7 @@ def main(): model = init_model(configs) print(model) num_params = sum(p.numel() for p in model.parameters()) - print('the number of model params: {}'.format(num_params)) + print('the number of model params: {:,d}'.format(num_params)) # !!!IMPORTANT!!! # Try to export the model by script, if fails, we should refine @@ -248,13 +248,13 @@ def main(): elif configs['optim'] == 'adamw': optimizer = optim.AdamW(model.parameters(), **configs['optim_conf']) else: - raise Exception('Please choose a correct optimizer.') + raise ValueError("unknown optimizer: " + configs['optim']) if configs['scheduler'] == 'warmuplr': scheduler = WarmupLR(optimizer, **configs['scheduler_conf']) elif configs['scheduler'] == 'NoamHoldAnnealing': scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf']) else: - raise Exception('Please choose a correct scheduler.') + raise ValueError("unknown scheduler: " + configs['scheduler']) final_epoch = None configs['rank'] = args.rank diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index 7e5cbdbe3..95dd6d201 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -13,12 +13,13 @@ # limitations under the License. # Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) # Squeezeformer(https://github.com/upskyy/Squeezeformer) +# NeMo(https://github.com/NVIDIA/NeMo) import torch import torch.nn as nn from typing import Tuple from wenet.squeezeformer.subsampling \ - import DepthwiseConv2dSubsampling4, TimeReductionLayer + import DepthwiseConv2dSubsampling4, TimeReductionLayer1D, TimeReductionLayer2D from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer from wenet.transformer.embedding import RelPositionalEncoding from wenet.transformer.attention import MultiHeadedAttention @@ -43,6 +44,7 @@ def __init__( feed_forward_expansion_factor: int = 4, input_dropout_rate: float = 0.1, pos_enc_layer_type: str = "rel_pos", + time_reduction_layer_type: str = "conv2d", do_rel_shift: bool = True, feed_forward_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.1, @@ -73,6 +75,7 @@ def __init__( feed_forward_expansion_factor (int): Enlarge coefficient of FFN. input_dropout_rate (float): Dropout rate of input projection layer. pos_enc_layer_type (str): Self attention type. + time_reduction_layer_type (str): Conv1d or Conv2d reduction layer. do_rel_shift (bool): Whether to do relative shift operation on rel-attention module. cnn_module_kernel (int): Kernel size of CNN module. 
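[Reviewer note, not part of the patch] To make the `time_reduction_layer_type` option documented above concrete, here is a minimal, self-contained sketch of what the `conv1d` choice computes: a depthwise convolution with stride 2 halves the frame rate, a pointwise convolution then mixes channels, and the result is trimmed to the subsampled mask length. Shapes and names (`dw_conv`, `pw_conv`, B/T/D) are illustrative only; kernel 5, stride 2 and padding = kernel - stride mirror the patch's TimeReductionLayer1D defaults.

import torch
import torch.nn as nn

B, T, D = 2, 50, 256
dw_conv = nn.Conv1d(D, D, kernel_size=5, stride=2, padding=3, groups=D)  # depthwise, stride 2
pw_conv = nn.Conv1d(D, D, kernel_size=1)  # pointwise channel mixing

xs = torch.randn(B, T, D)                                  # (batch, time, feature)
ys = pw_conv(dw_conv(xs.transpose(1, 2))).transpose(1, 2)  # conv works on (B, C, T)
ys = ys[:, :(T + 1) // 2, :]   # trim to the subsampled length, as the patch does
print(ys.shape)                # torch.Size([2, 25, 256]) -- 40ms frames become 80ms
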
@@ -155,7 +158,17 @@ def __init__( dropout, concat_after) for _ in range(num_blocks) ]) - self.time_reduction_layer = TimeReductionLayer(encoder_dim=encoder_dim) + if time_reduction_layer_type == 'conv1d': + time_reduction_layer = TimeReductionLayer1D + time_reduction_layer_args = { + 'channel': encoder_dim, + 'out_dim': encoder_dim, + } + else: + time_reduction_layer = TimeReductionLayer2D + time_reduction_layer_args = {'encoder_dim': encoder_dim} + + self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) self.final_proj = None if output_size != encoder_dim: @@ -196,19 +209,16 @@ def forward( recover_chunk_masks = chunk_masks recover_pos_emb = pos_emb recover_mask_pad = mask_pad - xs, xs_lens = self.time_reduction_layer(xs, xs_lens) - reduce_t = xs.size(1) - pos_emb = pos_emb[:, :reduce_t, :] - chunk_masks = chunk_masks[:, ::2, ::2] - mask_pad = mask_pad[:, :, ::2] + xs, xs_lens, chunk_masks, mask_pad = \ + self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) + pos_emb = pos_emb[:, :xs.size(1), :] if idx == self.recover_idx: # recover output length for ctc decode xs = xs.unsqueeze(2) xs = xs.repeat(1, 1, 2, 1).flatten(1, 2) xs = self.time_recover_layer(xs) - recover_t = recover_tensor.size(1) - xs = recover_tensor + xs[:, :recover_t, :].contiguous() + xs = recover_tensor + xs[:, :recover_tensor.size(1), :].contiguous() chunk_masks = recover_chunk_masks pos_emb = recover_pos_emb mask_pad = recover_mask_pad diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index 6ce394376..16372104d 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -13,6 +13,7 @@ # limitations under the License. # Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) # Squeezeformer(https://github.com/upskyy/Squeezeformer) +# NeMo(https://github.com/NVIDIA/NeMo) """DepthwiseConv2dSubsampling4 and TimeReductionLayer definition.""" @@ -68,11 +69,81 @@ def forward( return x, pos_emb, x_mask[:, :, :-2:2][:, :, :-2:2] -class TimeReductionLayer(nn.Module): +class TimeReductionLayer1D(nn.Module): + """ + Modified NeMo, + Squeezeformer Time Reduction procedure. Downsamples the audio by `stride` in the time dimension. + Args: + channel (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward + out_dim (int): Output dimension of the module. + kernel_size (int): Conv kernel size for depthwise convolution in convolution module + stride (int): Downsampling factor in time dimension. 
+ """ + + def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int = 2): + super(TimeReductionLayer1D, self).__init__() + + self.channel = channel + self.out_dim = out_dim + self.kernel_size = kernel_size + self.stride = stride + self.padding = max(0, self.kernel_size - self.stride) + + self.dw_conv = nn.Conv1d( + in_channels=channel, + out_channels=channel, + kernel_size=kernel_size, + stride=stride, + padding=self.padding, + groups=channel, + ) + + self.pw_conv = nn.Conv1d( + in_channels=channel, out_channels=out_dim, kernel_size=1, stride=1, padding=0, groups=1, + ) + + self.init_weights() + + def init_weights(self): + dw_max = self.kernel_size ** -0.5 + pw_max = self.channel ** -0.5 + torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) + + def forward(self, xs, xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ): + xs = xs.transpose(1, 2) # [B, C, T] + xs = xs.masked_fill(mask_pad.eq(0), 0.0) + + xs = self.dw_conv(xs) + xs = self.pw_conv(xs) + + xs = xs.transpose(1, 2) # [B, T, C] + + B, T, D = xs.size() + mask = mask[:, ::self.stride, ::self.stride] + mask_pad = mask_pad[:, :, ::self.stride] + L = mask_pad.size(-1) + # For JIT exporting, we remove F.pad operator. + if L - T < 0: + xs = xs[:, :L - T, :].contiguous() + else: + dummy_pad = torch.zeros(B, L - T, D, device=xs.device) + xs = torch.cat([xs, dummy_pad], dim=1) + + xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc') + return xs, xs_lens, mask, mask_pad + + +class TimeReductionLayer2D(nn.Module): def __init__( self, ichannel: int = 1, ochannel: int = 1, kernel_size: int = 5, stride: int = 2, encoder_dim: int = 256): - super(TimeReductionLayer, self).__init__() + super(TimeReductionLayer2D, self).__init__() self.ichannel = ichannel self.kernel_size = kernel_size self.dw_conv = Conv2dValid( @@ -104,8 +175,11 @@ def init_weights(self): torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + xs = xs.masked_fill(mask_pad.transpose(1, 2).eq(0), 0.0) xs = xs.unsqueeze(2) padding1 = self.kernel_size - self.stride xs = F.pad(xs, (0, 0, 0, 0, 0, padding1, 0, 0), @@ -117,6 +191,8 @@ def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor xs_lens = torch.div(xs_lens + 1, 2, rounding_mode='trunc') padding2 = max(0, (xs_lens.max() - tmp_length).data.item()) batch_size, hidden = xs.size(0), xs.size(-1) - dummy_pad = torch.zeros(batch_size, padding2, hidden).to(xs.device) + dummy_pad = torch.zeros(batch_size, padding2, hidden, device=xs.device) xs = torch.cat([xs, dummy_pad], dim=1) - return xs, xs_lens + mask = mask[:, ::2, ::2] + mask_pad = mask_pad[:, :, ::2] + return xs, xs_lens, mask, mask_pad From ba7ed740bc21bac774717c3ba601e2bc5f17b1a3 Mon Sep 17 00:00:00 2001 From: yygle Date: Fri, 23 Sep 2022 17:52:35 +0800 Subject: [PATCH 18/29] fix formatting issues --- 
wenet/squeezeformer/subsampling.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index 16372104d..21a9997a7 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -72,15 +72,19 @@ def forward( class TimeReductionLayer1D(nn.Module): """ Modified NeMo, - Squeezeformer Time Reduction procedure. Downsamples the audio by `stride` in the time dimension. + Squeezeformer Time Reduction procedure. + Downsamples the audio by `stride` in the time dimension. Args: - channel (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward + channel (int): input dimension of + MultiheadAttention and PositionwiseFeedForward out_dim (int): Output dimension of the module. - kernel_size (int): Conv kernel size for depthwise convolution in convolution module + kernel_size (int): Conv kernel size + for depthwise convolution in convolution module stride (int): Downsampling factor in time dimension. """ - def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int = 2): + def __init__(self, channel: int, out_dim: int, + kernel_size: int = 5, stride: int = 2): super(TimeReductionLayer1D, self).__init__() self.channel = channel @@ -99,7 +103,12 @@ def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int ) self.pw_conv = nn.Conv1d( - in_channels=channel, out_channels=out_dim, kernel_size=1, stride=1, padding=0, groups=1, + in_channels=channel, + out_channels=out_dim, + kernel_size=1, + stride=1, + padding=0, + groups=1, ) self.init_weights() From 027c85c32b9ec6eae31f059d427a3020cef4bc2f Mon Sep 17 00:00:00 2001 From: yygle Date: Fri, 23 Sep 2022 18:15:00 +0800 Subject: [PATCH 19/29] [update] operators --- wenet/squeezeformer/encoder.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index 95dd6d201..290112341 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -215,10 +215,9 @@ def forward( if idx == self.recover_idx: # recover output length for ctc decode - xs = xs.unsqueeze(2) - xs = xs.repeat(1, 1, 2, 1).flatten(1, 2) + xs = torch.repeat_interleave(xs, repeats=2, dim=1) xs = self.time_recover_layer(xs) - xs = recover_tensor + xs[:, :recover_tensor.size(1), :].contiguous() + xs = recover_tensor + xs[:, :recover_tensor.size(1), :] chunk_masks = recover_chunk_masks pos_emb = recover_pos_emb mask_pad = recover_mask_pad From ed342f259b5af7b766c68a8c3332928cdd06bc44 Mon Sep 17 00:00:00 2001 From: yygle Date: Sun, 25 Sep 2022 22:19:34 +0800 Subject: [PATCH 20/29] [update] experiment results & code format --- examples/librispeech/squeezeformer/README.md | 55 ++++++++++++++++---- wenet/squeezeformer/encoder.py | 15 +++--- wenet/squeezeformer/subsampling.py | 36 +++++++------ 3 files changed, 73 insertions(+), 33 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index be3140aa6..552caff3e 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -40,11 +40,17 @@ squeezeformer | attention rescoring | 3.18 | 8.72 | ### Squeezeformer Result (SM12, FFN:1024) -* Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912 -* Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_squeezeformer.yaml, kernel size 31, -* batch size 12, 8 gpu, acc_grad 
4, 120 epochs, dither 0.1 -* AdamW, lr=1e-3, NoamHold, warmup=0.2, hold=0.3, lr_decay=1.0 + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv2d + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31, + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr=1e-3, noamhold, warmup=0.2, hold=0.3, lr_decay=1.0 * Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 | decoding mode | dev clean | dev other | test clean | test other | @@ -55,12 +61,18 @@ squeezeformer | attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | ### Squeezeformer Result (SM12, FFN:2048) -* Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv2d + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 * Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_squeezeformer.yaml, kernel size 31, -* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 -* AdamW, lr=1e-3, NoamHold, warmup=0.2, hold=0.3, lr_decay=1.0 -* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 +* Training info: + * train_squeezeformer.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 | decoding mode | dev clean | dev other | test clean | test other | |----------------------------------|-----------|-----------|------------|------------| @@ -68,3 +80,26 @@ squeezeformer | ctc prefix beam search | 3.33 | 9.02 | 3.46 | 8.81 | | attention decoder | 3.64 | 8.62 | 3.91 | 8.33 | | attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 | + +### Squeezeformer Result (ML12, FFN:1312) + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d + * encoder_dim 328, output_size 256, head 4, ffn_dim 328*4=1312 + * encoder FLOPs(30s): 34,103,960,008, params: 35,678,352 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31, + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 + * adamw, lr 5e-4, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.20 | 8.46 | 3.30 | 8.58 | +| ctc prefix beam search | 3.18 | 8.44 | 3.30 | 8.55 | +| attention decoder | 3.38 | 8.31 | 3.89 | 8.32 | +| attention rescoring | 2.81 | 7.86 | 2.96 | 7.91 | diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index 290112341..88a328ff3 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -44,7 +44,7 @@ def __init__( feed_forward_expansion_factor: int = 4, input_dropout_rate: float = 0.1, pos_enc_layer_type: str = "rel_pos", - time_reduction_layer_type: str = "conv2d", + time_reduction_layer_type: str = "conv1d", do_rel_shift: bool = True, feed_forward_dropout_rate: float = 0.1, attention_dropout_rate: float = 0.1, @@ -144,9 +144,11 @@ def __init__( if init_weights: linear_max = (encoder_dim * input_size / 4) 
** -0.5 torch.nn.init.uniform_( - self.input_proj.state_dict()['0.weight'], -linear_max, linear_max) + self.input_proj.state_dict()['0.weight'], + -linear_max, linear_max) torch.nn.init.uniform_( - self.input_proj.state_dict()['0.bias'], -linear_max, linear_max) + self.input_proj.state_dict()['0.bias'], + -linear_max, linear_max) self.preln = nn.LayerNorm(encoder_dim) self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( encoder_dim, @@ -161,14 +163,15 @@ def __init__( if time_reduction_layer_type == 'conv1d': time_reduction_layer = TimeReductionLayer1D time_reduction_layer_args = { - 'channel': encoder_dim, - 'out_dim': encoder_dim, + 'ichannel': encoder_dim, + 'ochannel': encoder_dim, } else: time_reduction_layer = TimeReductionLayer2D time_reduction_layer_args = {'encoder_dim': encoder_dim} - self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) + self.time_reduction_layer = \ + time_reduction_layer(**time_reduction_layer_args) self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) self.final_proj = None if output_size != encoder_dim: diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index 21a9997a7..d0263bb57 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -83,28 +83,28 @@ class TimeReductionLayer1D(nn.Module): stride (int): Downsampling factor in time dimension. """ - def __init__(self, channel: int, out_dim: int, + def __init__(self, ichannel: int, ochannel: int, kernel_size: int = 5, stride: int = 2): super(TimeReductionLayer1D, self).__init__() - self.channel = channel - self.out_dim = out_dim + self.ichannel = ichannel + self.ochannel = ochannel self.kernel_size = kernel_size self.stride = stride self.padding = max(0, self.kernel_size - self.stride) self.dw_conv = nn.Conv1d( - in_channels=channel, - out_channels=channel, + in_channels=ichannel, + out_channels=ichannel, kernel_size=kernel_size, stride=stride, padding=self.padding, - groups=channel, + groups=ichannel, ) self.pw_conv = nn.Conv1d( - in_channels=channel, - out_channels=out_dim, + in_channels=ichannel, + out_channels=ochannel, kernel_size=1, stride=1, padding=0, @@ -115,16 +115,17 @@ def __init__(self, channel: int, out_dim: int, def init_weights(self): dw_max = self.kernel_size ** -0.5 - pw_max = self.channel ** -0.5 + pw_max = self.ichannel ** -0.5 torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward(self, xs, xs_lens: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ): + def forward( + self, xs, xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ): xs = xs.transpose(1, 2) # [B, C, T] xs = xs.masked_fill(mask_pad.eq(0), 0.0) @@ -184,10 +185,11 @@ def init_weights(self): torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward( + self, xs: torch.Tensor, xs_lens: 
torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: xs = xs.masked_fill(mask_pad.transpose(1, 2).eq(0), 0.0) xs = xs.unsqueeze(2) padding1 = self.kernel_size - self.stride From ac4013c52e39a5686e76f9b09dc3d6e7c6dfdf13 Mon Sep 17 00:00:00 2001 From: yygle Date: Sun, 25 Sep 2022 22:27:37 +0800 Subject: [PATCH 21/29] [update] experiment results --- examples/librispeech/squeezeformer/README.md | 21 +++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index 552caff3e..eb7e65421 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md @@ -39,8 +39,22 @@ squeezeformer | attention decoder | 3.05 | 8.36 | | attention rescoring | 3.18 | 8.72 | -### Squeezeformer Result (SM12, FFN:1024) +### Conformer Result (12 layers, FFN:2048) +* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: train_squeezeformer.yaml, kernel size 31, +* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.49 | 9.59 | 3.66 | 9.59 | +| ctc prefix beam search | 3.49 | 9.61 | 3.66 | 9.55 | +| attention decoder | 3.52 | 9.04 | 3.85 | 8.97 | +| attention rescoring | 3.10 | 8.91 | 3.29 | 8.81 | +### Squeezeformer Result (SM12, FFN:1024) * Encoder info: * SM12, reduce_idx 5, recover_idx 11, conv2d * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 @@ -61,7 +75,6 @@ squeezeformer | attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | ### Squeezeformer Result (SM12, FFN:2048) - * Encoder info: * SM12, reduce_idx 5, recover_idx 11, conv2d * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 @@ -81,8 +94,7 @@ squeezeformer | attention decoder | 3.64 | 8.62 | 3.91 | 8.33 | | attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 | -### Squeezeformer Result (ML12, FFN:1312) - +### Squeezeformer Result (SM12, FFN:1312) * Encoder info: * SM12, reduce_idx 5, recover_idx 11, conv1d * encoder_dim 328, output_size 256, head 4, ffn_dim 328*4=1312 @@ -96,7 +108,6 @@ squeezeformer * Decoding info: * ctc_weight 0.3, reverse weight 0.5, average_num 30 - | decoding mode | dev clean | dev other | test clean | test other | |----------------------------------|-----------|-----------|------------|------------| | ctc greedy search | 3.20 | 8.46 | 3.30 | 8.58 | From 6592ae3b64c6ba738515628f1fc16769b7cf46d2 Mon Sep 17 00:00:00 2001 From: yygle Date: Sat, 8 Oct 2022 16:48:11 +0800 Subject: [PATCH 22/29] [update] streaming support & results, dw_stride trigger --- examples/librispeech/squeezeformer/README.md | 73 ++- wenet/squeezeformer/encoder_layer.py | 449 +++++++++++++++---- wenet/squeezeformer/subsampling.py | 57 ++- 3 files changed, 452 insertions(+), 127 deletions(-) diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md index eb7e65421..322bac803 100644 --- a/examples/librispeech/squeezeformer/README.md +++ b/examples/librispeech/squeezeformer/README.md 
@@ -56,8 +56,8 @@ squeezeformer ### Squeezeformer Result (SM12, FFN:1024) * Encoder info: - * SM12, reduce_idx 5, recover_idx 11, conv2d - * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*4=1024 * Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912 * Feature info: * using fbank feature, cmvn, dither, online speed perturb @@ -76,7 +76,7 @@ squeezeformer ### Squeezeformer Result (SM12, FFN:2048) * Encoder info: - * SM12, reduce_idx 5, recover_idx 11, conv2d + * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 * encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 * Feature info: using fbank feature, cmvn, dither, online speed perturb @@ -96,7 +96,7 @@ squeezeformer ### Squeezeformer Result (SM12, FFN:1312) * Encoder info: - * SM12, reduce_idx 5, recover_idx 11, conv1d + * SM12, reduce_idx 5, recover_idx 11, conv1d, w/o syncbn * encoder_dim 328, output_size 256, head 4, ffn_dim 328*4=1312 * encoder FLOPs(30s): 34,103,960,008, params: 35,678,352 * Feature info: @@ -104,7 +104,7 @@ squeezeformer * Training info: * train_squeezeformer.yaml, kernel size 31, * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 - * adamw, lr 5e-4, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 + * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 * Decoding info: * ctc_weight 0.3, reverse weight 0.5, average_num 30 @@ -114,3 +114,66 @@ squeezeformer | ctc prefix beam search | 3.18 | 8.44 | 3.30 | 8.55 | | attention decoder | 3.38 | 8.31 | 3.89 | 8.32 | | attention rescoring | 2.81 | 7.86 | 2.96 | 7.91 | + +### Conformer U2++ Result + +* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 +* Feature info: using fbank feature, cmvn, no speed perturb, dither +* Training info: train_u2++_conformer.yaml lr 0.001, batch size 24, 8 gpu, acc_grad 1, 120 epochs, dither 1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 +* Git hash: 65270043fc8c2476d1ab95e7c39f730017a670e0 + +test clean + +| decoding mode | full | 16 | +|--------------------------------|------|------| +| ctc prefix beam search | 3.76 | 4.54 | +| attention rescoring | 3.32 | 3.80 | + +test other + +| decoding mode | full | 16 | +|--------------------------------|-------|-------| +| ctc prefix beam search | 9.50 | 11.52 | +| attention rescoring | 8.67 | 10.38 | + +### Squeezeformer Result (U2++, FFN:2048) + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d, layer_norm, do_rel_shift false + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr 1e-3, NoamHold, warmup 0.1, hold 0.4, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +test clean + +| decoding mode | full | 16 | +|--------------------------------|------|------| +| ctc prefix beam search | 3.81 | 4.59 | +| attention rescoring | 3.36 | 3.93 | + +test other + +| decoding mode | full | 16 | +|--------------------------------|-------|-------| +| ctc prefix beam search | 9.12 | 11.17 | +| attention rescoring | 8.43 | 10.21 | + +### Conformer Result Bidecoder (large) + +* Encoder FLOPs(30s): 96,238,430,720, params: 85,709,704 +* Feature 
info: using fbank feature, cmvn, dither, online speed perturb +* Training info: train_conformer_bidecoder_large.yaml, kernel size 31, lr 0.002, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | test clean | test other | +|----------------------------------|------------|------------| +| ctc prefix beam search | 2.96 | 7.14 | +| attention rescoring | 2.66 | 6.53 | diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py index 5d848ec63..f90868a71 100644 --- a/wenet/squeezeformer/encoder_layer.py +++ b/wenet/squeezeformer/encoder_layer.py @@ -11,112 +11,375 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -"""SqueezeformerEncoderLayer definition.""" +# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer) +# Squeezeformer(https://github.com/upskyy/Squeezeformer) +# NeMo(https://github.com/NVIDIA/NeMo) import torch import torch.nn as nn -from typing import Optional, Tuple +from typing import Tuple +from wenet.squeezeformer.subsampling \ + import DepthwiseConv2dSubsampling4, TimeReductionLayer1D, TimeReductionLayer2D +from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer +from wenet.transformer.embedding import RelPositionalEncoding +from wenet.transformer.attention import MultiHeadedAttention +from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention +from wenet.squeezeformer.positionwise_feed_forward \ + import PositionwiseFeedForward +from wenet.squeezeformer.convolution import ConvolutionModule +from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask +from wenet.utils.common import get_activation -class SqueezeformerEncoderLayer(nn.Module): - """Encoder layer module. - Args: - size (int): Input dimension. - self_attn (torch.nn.Module): Self-attention module instance. - `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` - instance can be used as the argument. - feed_forward1 (torch.nn.Module): Feed-forward module instance. - `PositionwiseFeedForward` instance can be used as the argument. - conv_module (torch.nn.Module): Convolution module instance. - `ConvlutionModule` instance can be used as the argument. - feed_forward2 (torch.nn.Module): Feed-forward module instance. - `PositionwiseFeedForward` instance can be used as the argument. - dropout_rate (float): Dropout rate. - normalize_before (bool): - True: use layer_norm before each sub-block. - False: use layer_norm after each sub-block. 
- """ - +class SqueezeformerEncoder(nn.Module): def __init__( self, - size: int, - self_attn: torch.nn.Module, - feed_forward1: Optional[nn.Module] = None, - conv_module: Optional[nn.Module] = None, - feed_forward2: Optional[nn.Module] = None, + input_size: int = 80, + encoder_dim: int = 256, + output_size: int = 256, + attention_heads: int = 4, + num_blocks: int = 12, + reduce_idx: int = 5, + recover_idx: int = 11, + feed_forward_expansion_factor: int = 4, + dw_stride: bool = False, + input_dropout_rate: float = 0.1, + pos_enc_layer_type: str = "rel_pos", + time_reduction_layer_type: str = "conv1d", + do_rel_shift: bool = True, + feed_forward_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + cnn_module_kernel: int = 31, + cnn_norm_type: str = "batch_norm", + dropout: float = 0.1, + causal: bool = False, + adaptive_scale: bool = True, + activation_type: str = "swish", + init_weights: bool = True, + global_cmvn: torch.nn.Module = None, normalize_before: bool = False, - dropout_rate: float = 0.1, + use_dynamic_chunk: bool = False, concat_after: bool = False, + static_chunk_size: int = 0, + use_dynamic_left_chunk: bool = False ): - super(SqueezeformerEncoderLayer, self).__init__() - self.size = size - self.self_attn = self_attn - self.layer_norm1 = nn.LayerNorm(size) - self.ffn1 = feed_forward1 - self.layer_norm2 = nn.LayerNorm(size) - self.conv_module = conv_module - self.layer_norm3 = nn.LayerNorm(size) - self.ffn2 = feed_forward2 - self.layer_norm4 = nn.LayerNorm(size) - self.normalize_before = normalize_before - self.dropout = nn.Dropout(dropout_rate) - self.concat_after = concat_after - if concat_after: - self.concat_linear = nn.Linear(size + size, size) + """Construct SqueezeformerEncoder + + Args: + input_size to use_dynamic_chunk, see in Transformer BaseEncoder. + encoder_dim (int): The hidden dimension of encoder layer. + output_size (int): The output dimension of final projection layer. + attention_heads (int): Num of attention head in attention module. + num_blocks (int): Num of encoder layers. + reduce_idx (int): reduce layer index, from 40ms to 80ms per frame. + recover_idx (int): recover layer index, from 80ms to 40ms per frame. + feed_forward_expansion_factor (int): Enlarge coefficient of FFN. + dw_stride (bool): Whether do depthwise convolution + on subsampling module. + input_dropout_rate (float): Dropout rate of input projection layer. + pos_enc_layer_type (str): Self attention type. + do_rel_shift (bool): Whether to do relative shift + operation on rel-attention module. + cnn_module_kernel (int): Kernel size of CNN module. + activation_type (str): Encoder activation function type. + use_cnn_module (bool): Whether to use convolution module. + cnn_module_kernel (int): Kernel size of convolution module. + adaptive_scale (bool): Whether to use adaptive scale. + init_weights (bool): Whether to initialize weights. + causal (bool): whether to use causal convolution or not. 
+ """ + super(SqueezeformerEncoder, self).__init__() + self.global_cmvn = global_cmvn + self.reduce_idx = reduce_idx + self.recover_idx = recover_idx + self._output_size = output_size + self.static_chunk_size = static_chunk_size + self.use_dynamic_chunk = use_dynamic_chunk + self.use_dynamic_left_chunk = use_dynamic_left_chunk + activation = get_activation(activation_type) + + # self-attention module definition + if pos_enc_layer_type != "rel_pos": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + else: + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + encoder_dim, + attention_dropout_rate, + do_rel_shift, + adaptive_scale, + init_weights + ) + + # feed-forward module definition + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + encoder_dim, + encoder_dim * feed_forward_expansion_factor, + feed_forward_dropout_rate, + activation, + adaptive_scale, + init_weights + ) + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = ( + encoder_dim, cnn_module_kernel, activation, + cnn_norm_type, causal, True, adaptive_scale, init_weights) + + self.embed = DepthwiseConv2dSubsampling4( + 1, encoder_dim, + RelPositionalEncoding(encoder_dim, dropout_rate=0.1), + dw_stride + ) + self.input_proj = nn.Sequential( + nn.Linear( + encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim), + nn.Dropout(p=input_dropout_rate), + ) + if init_weights: + linear_max = (encoder_dim * input_size / 4) ** -0.5 + torch.nn.init.uniform_( + self.input_proj.state_dict()['0.weight'], -linear_max, linear_max) + torch.nn.init.uniform_( + self.input_proj.state_dict()['0.bias'], -linear_max, linear_max) + self.preln = nn.LayerNorm(encoder_dim) + self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( + encoder_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + convolution_layer(*convolution_layer_args), + positionwise_layer(*positionwise_layer_args), + normalize_before, + dropout, + concat_after) for _ in range(num_blocks) + ]) + if time_reduction_layer_type == 'conv1d': + time_reduction_layer = TimeReductionLayer1D + time_reduction_layer_args = { + 'channel': encoder_dim, + 'out_dim': encoder_dim, + } else: - self.concat_linear = nn.Identity() + time_reduction_layer = TimeReductionLayer2D + time_reduction_layer_args = {'encoder_dim': encoder_dim} + + self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) + self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) + self.final_proj = None + if output_size != encoder_dim: + self.final_proj = nn.Linear(encoder_dim, output_size) + + def output_size(self) -> int: + return self._output_size def forward( self, - x: torch.Tensor, - mask: torch.Tensor, - pos_emb: torch.Tensor, - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # self attention module - residual = x - if self.normalize_before: - x = self.layer_norm1(x) - x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) - if self.concat_after: - x_concat = torch.cat((x, x_att), dim=-1) - x = residual + self.concat_linear(x_concat) - else: - x = residual + self.dropout(x_att) - if not 
self.normalize_before: - x = self.layer_norm1(x) + xs: torch.Tensor, + xs_lens: torch.Tensor, + decoding_chunk_size: int = 0, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + T = xs.size(1) + masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + xs, pos_emb, masks = self.embed(xs, masks) + mask_pad = masks # (B, 1, T/subsample_rate) + chunk_masks = add_optional_chunk_mask(xs, masks, + self.use_dynamic_chunk, + self.use_dynamic_left_chunk, + decoding_chunk_size, + self.static_chunk_size, + num_decoding_left_chunks) + xs_lens = chunk_masks.squeeze(1).sum(1) + xs = self.input_proj(xs) + xs = self.preln(xs) + recover_tensor = torch.tensor(0.) + recover_chunk_masks = torch.tensor(0.) + recover_pos_emb = torch.tensor(0.) + recover_mask_pad = torch.tensor(0.) + for idx, layer in enumerate(self.encoders): + if idx == self.reduce_idx: + recover_tensor = xs + recover_chunk_masks = chunk_masks + recover_pos_emb = pos_emb + recover_mask_pad = mask_pad + xs, xs_lens, chunk_masks, mask_pad = \ + self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) + pos_emb = pos_emb[:, :xs.size(1), :] - # ffn module - residual = x - if self.normalize_before: - x = self.layer_norm2(x) - x = self.ffn1(x) - # we do not use dropout here since it is inside feed forward function - x = residual + self.dropout(x) - if not self.normalize_before: - x = self.layer_norm2(x) - - # conv module - new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) - residual = x - if self.normalize_before: - x = self.layer_norm3(x) - x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) - x = residual + self.dropout(x) - if not self.normalize_before: - x = self.layer_norm3(x) - - # ffn module - residual = x + if idx == self.recover_idx: + # recover output length for ctc decode + xs = torch.repeat_interleave(xs, repeats=2, dim=1) + xs = self.time_recover_layer(xs) + recover_t = recover_tensor.size(1) + xs = recover_tensor + xs[:, :recover_t, :].contiguous() + chunk_masks = recover_chunk_masks + pos_emb = recover_pos_emb + mask_pad = recover_mask_pad + + if self.reduce_idx <= idx < self.recover_idx: + residual = xs + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + xs += residual + else: + xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) + + if self.final_proj is not None: + xs = self.final_proj(xs) + return xs, masks + + def forward_chunk( + self, + xs: torch.Tensor, + offset: int, + required_cache_size: int, + att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), + att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ Forward just one chunk + + Args: + xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), + where `time == (chunk_size - 1) * subsample_rate + \ + subsample.right_context + 1` + offset (int): current offset in encoder output time stamp + required_cache_size (int): cache size required for next chunk + compuation + >=0: actual cache size + <0: means all history cache is required + att_cache (torch.Tensor): cache tensor for KEY & VALUE in + transformer/conformer attention, with shape + (elayers, head, cache_t1, d_k * 2), where + `head * d_k == hidden-dim` and + `cache_t1 == chunk_size * num_decoding_left_chunks`. 
+ cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, + (elayers, b=1, hidden-dim, cache_t2), where + `cache_t2 == cnn.lorder - 1` + + Returns: + torch.Tensor: output of current input xs, + with shape (b=1, chunk_size, hidden-dim). + torch.Tensor: new attention cache required for next chunk, with + dynamic shape (elayers, head, ?, d_k * 2) + depending on required_cache_size. + torch.Tensor: new conformer cnn cache required for next chunk, with + same shape as the original cnn_cache. + + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding( + offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, att_mask, pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache + ) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) if self.normalize_before: - x = self.layer_norm4(x) - x = self.ffn2(x) - # we do not use dropout here since it is inside feed forward function - x = residual + self.dropout(x) - if not self.normalize_before: - x = self.layer_norm4(x) - - return x, mask, new_att_cache, new_cnn_cache + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. 
We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not prefered. + Args: + xs (torch.Tensor): (1, max_len, dim) + chunk_size (int): decoding chunk size + """ + assert decoding_chunk_size > 0 + # The model is trained by static or dynamic chunk + assert self.static_chunk_size > 0 or self.use_dynamic_chunk + subsampling = self.embed.subsampling_rate + context = self.embed.right_context + 1 # Add current frame + stride = subsampling * decoding_chunk_size + decoding_window = (decoding_chunk_size - 1) * subsampling + context + num_frames = xs.size(1) + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) + outputs = [] + offset = 0 + required_cache_size = decoding_chunk_size * num_decoding_left_chunks + + # Feed forward overlap input step by step + for cur in range(0, num_frames - context + 1, stride): + end = min(cur + decoding_window, num_frames) + chunk_xs = xs[:, cur:end, :] + (y, att_cache, cnn_cache) = self.forward_chunk( + chunk_xs, offset, required_cache_size, att_cache, cnn_cache) + outputs.append(y) + offset += y.size(1) + ys = torch.cat(outputs, 1) + masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) + return ys, masks \ No newline at end of file diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index d0263bb57..8132c2326 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -21,7 +21,7 @@ import torch.nn as nn import torch.nn.functional as F from wenet.transformer.subsampling import BaseSubsampling -from typing import Tuple +from typing import Tuple, Optional from wenet.squeezeformer.conv2d import Conv2dValid @@ -36,7 +36,10 @@ class DepthwiseConv2dSubsampling4(BaseSubsampling): """ def __init__( - self, idim: int, odim: int, pos_enc_class: torch.nn.Module): + self, idim: int, odim: int, + pos_enc_class: torch.nn.Module, + dw_stride: bool = False + ): super(DepthwiseConv2dSubsampling4, self).__init__() self.idim = idim self.odim = odim @@ -44,7 +47,9 @@ def __init__( in_channels=idim, out_channels=odim, kernel_size=3, stride=2) self.act1 = nn.ReLU() self.dw_conv = nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2) + in_channels=odim, out_channels=odim, kernel_size=3, stride=2, + groups=odim if dw_stride else 1 + ) self.act2 = nn.ReLU() self.pos_enc = pos_enc_class self.subsampling_rate = 4 @@ -76,56 +81,51 @@ class TimeReductionLayer1D(nn.Module): Downsamples the audio by `stride` in the time dimension. Args: channel (int): input dimension of - MultiheadAttention and PositionwiseFeedForward + MultiheadAttentionMechanism and PositionwiseFeedForward out_dim (int): Output dimension of the module. - kernel_size (int): Conv kernel size - for depthwise convolution in convolution module + kernel_size (int): Conv kernel size for + depthwise convolution in convolution module stride (int): Downsampling factor in time dimension. 
""" - def __init__(self, ichannel: int, ochannel: int, + def __init__(self, channel: int, out_dim: int, kernel_size: int = 5, stride: int = 2): super(TimeReductionLayer1D, self).__init__() - self.ichannel = ichannel - self.ochannel = ochannel + self.channel = channel + self.out_dim = out_dim self.kernel_size = kernel_size self.stride = stride self.padding = max(0, self.kernel_size - self.stride) self.dw_conv = nn.Conv1d( - in_channels=ichannel, - out_channels=ichannel, + in_channels=channel, + out_channels=channel, kernel_size=kernel_size, stride=stride, padding=self.padding, - groups=ichannel, + groups=channel, ) self.pw_conv = nn.Conv1d( - in_channels=ichannel, - out_channels=ochannel, - kernel_size=1, - stride=1, - padding=0, - groups=1, + in_channels=channel, out_channels=out_dim, + kernel_size=1, stride=1, padding=0, groups=1, ) self.init_weights() def init_weights(self): dw_max = self.kernel_size ** -0.5 - pw_max = self.ichannel ** -0.5 + pw_max = self.channel ** -0.5 torch.nn.init.uniform_(self.dw_conv.weight, -dw_max, dw_max) torch.nn.init.uniform_(self.dw_conv.bias, -dw_max, dw_max) torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward( - self, xs, xs_lens: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ): + def forward(self, xs, xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ): xs = xs.transpose(1, 2) # [B, C, T] xs = xs.masked_fill(mask_pad.eq(0), 0.0) @@ -185,11 +185,10 @@ def init_weights(self): torch.nn.init.uniform_(self.pw_conv.weight, -pw_max, pw_max) torch.nn.init.uniform_(self.pw_conv.bias, -pw_max, pw_max) - def forward( - self, xs: torch.Tensor, xs_lens: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor, + mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: xs = xs.masked_fill(mask_pad.transpose(1, 2).eq(0), 0.0) xs = xs.unsqueeze(2) padding1 = self.kernel_size - self.stride From 67e260a7670ff93a744afd4430a2b20842ebf1cc Mon Sep 17 00:00:00 2001 From: yygle Date: Sat, 8 Oct 2022 16:51:29 +0800 Subject: [PATCH 23/29] fix formatting issue --- wenet/squeezeformer/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/squeezeformer/subsampling.py b/wenet/squeezeformer/subsampling.py index 8132c2326..d9cd89973 100644 --- a/wenet/squeezeformer/subsampling.py +++ b/wenet/squeezeformer/subsampling.py @@ -21,7 +21,7 @@ import torch.nn as nn import torch.nn.functional as F from wenet.transformer.subsampling import BaseSubsampling -from typing import Tuple, Optional +from typing import Tuple from wenet.squeezeformer.conv2d import Conv2dValid From 59733527b814d9540c78a5a3813a27138cbd960e Mon Sep 17 00:00:00 2001 From: yygle Date: Sun, 9 Oct 2022 22:37:04 +0800 Subject: [PATCH 24/29] fix formatting issue --- wenet/squeezeformer/encoder_layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py index 
f90868a71..dc95ef7fe 100644
--- a/wenet/squeezeformer/encoder_layer.py
+++ b/wenet/squeezeformer/encoder_layer.py
@@ -382,4 +382,4 @@ def forward_chunk_by_chunk(
             offset += y.size(1)
         ys = torch.cat(outputs, 1)
         masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
-        return ys, masks
\ No newline at end of file
+        return ys, masks

From 08c49aad6e9b9b77f14844dd08536b3748421134 Mon Sep 17 00:00:00 2001
From: yygle
Date: Sun, 9 Oct 2022 22:47:46 +0800
Subject: [PATCH 25/29] fix formatting issue

---
 wenet/squeezeformer/encoder.py       | 147 +++++++++
 wenet/squeezeformer/encoder_layer.py | 449 ++++++---------------------
 2 files changed, 240 insertions(+), 356 deletions(-)

diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py
index 88a328ff3..ad4ceec47 100644
--- a/wenet/squeezeformer/encoder.py
+++ b/wenet/squeezeformer/encoder.py
@@ -235,3 +235,150 @@ def forward(
         if self.final_proj is not None:
             xs = self.final_proj(xs)
         return xs, masks
+
+    def forward_chunk(
+        self,
+        xs: torch.Tensor,
+        offset: int,
+        required_cache_size: int,
+        att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
+        att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """ Forward just one chunk
+
+        Args:
+            xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
+                where `time == (chunk_size - 1) * subsample_rate + \
+                subsample.right_context + 1`
+            offset (int): current offset in encoder output timestamp
+            required_cache_size (int): cache size required for next chunk
+                computation
+                >=0: actual cache size
+                <0: means all history cache is required
+            att_cache (torch.Tensor): cache tensor for KEY & VALUE in
+                transformer/conformer attention, with shape
+                (elayers, head, cache_t1, d_k * 2), where
+                `head * d_k == hidden-dim` and
+                `cache_t1 == chunk_size * num_decoding_left_chunks`.
+            cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
+                (elayers, b=1, hidden-dim, cache_t2), where
+                `cache_t2 == cnn.lorder - 1`
+
+        Returns:
+            torch.Tensor: output of current input xs,
+                with shape (b=1, chunk_size, hidden-dim).
+            torch.Tensor: new attention cache required for next chunk, with
+                dynamic shape (elayers, head, ?, d_k * 2)
+                depending on required_cache_size.
+            torch.Tensor: new conformer cnn cache required for next chunk, with
+                same shape as the original cnn_cache.
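The `required_cache_size` contract documented here is easy to get wrong, so a small stand-alone illustration of the slicing rule the implementation applies (the function name is ours, for illustration only):

```python
def next_cache_start(attention_key_size: int, required_cache_size: int) -> int:
    """Where to slice the per-layer attention cache for the next chunk."""
    if required_cache_size < 0:      # keep the full history
        return 0
    if required_cache_size == 0:     # keep nothing
        return attention_key_size
    return max(attention_key_size - required_cache_size, 0)


assert next_cache_start(80, -1) == 0    # all 80 cached key/value frames survive
assert next_cache_start(80, 0) == 80    # cache is emptied after this chunk
assert next_cache_start(80, 64) == 16   # only the last 64 frames survive
```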
+ + """ + assert xs.size(0) == 1 + # tmp_masks is just for interface compatibility + tmp_masks = torch.ones(1, + xs.size(1), + device=xs.device, + dtype=torch.bool) + tmp_masks = tmp_masks.unsqueeze(1) + if self.global_cmvn is not None: + xs = self.global_cmvn(xs) + # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) + xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) + # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) + elayers, cache_t1 = att_cache.size(0), att_cache.size(2) + chunk_size = xs.size(1) + attention_key_size = cache_t1 + chunk_size + pos_emb = self.embed.position_encoding( + offset=offset - cache_t1, size=attention_key_size) + if required_cache_size < 0: + next_cache_start = 0 + elif required_cache_size == 0: + next_cache_start = attention_key_size + else: + next_cache_start = max(attention_key_size - required_cache_size, 0) + r_att_cache = [] + r_cnn_cache = [] + for i, layer in enumerate(self.encoders): + # NOTE(xcsong): Before layer.forward + # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), + # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) + xs, _, new_att_cache, new_cnn_cache = layer( + xs, att_mask, pos_emb, + att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, + cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache + ) + # NOTE(xcsong): After layer.forward + # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), + # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) + r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) + r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + if self.normalize_before: + xs = self.after_norm(xs) + + # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), + # ? may be larger than cache_t1, it depends on required_cache_size + r_att_cache = torch.cat(r_att_cache, dim=0) + # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) + r_cnn_cache = torch.cat(r_cnn_cache, dim=0) + + return (xs, r_att_cache, r_cnn_cache) + + def forward_chunk_by_chunk( + self, + xs: torch.Tensor, + decoding_chunk_size: int, + num_decoding_left_chunks: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ Forward input chunk by chunk with chunk_size like a streaming + fashion + + Here we should pay special attention to computation cache in the + streaming style forward chunk by chunk. Three things should be taken + into account for computation in the current network: + 1. transformer/conformer encoder layers output cache + 2. convolution in conformer + 3. convolution in subsampling + + However, we don't implement subsampling cache for: + 1. We can control subsampling module to output the right result by + overlapping input instead of cache left context, even though it + wastes some computation, but subsampling only takes a very + small fraction of computation in the whole model. + 2. Typically, there are several covolution layers with subsampling + in subsampling module, it is tricky and complicated to do cache + with different convolution layers with different subsampling + rate. + 3. Currently, nn.Sequential is used to stack all the convolution + layers in subsampling, we need to rewrite it to make it work + with cache, which is not prefered. 
+        Args:
+            xs (torch.Tensor): (1, max_len, dim)
+            decoding_chunk_size (int): decoding chunk size
+        """
+        assert decoding_chunk_size > 0
+        # The model is trained with a static or dynamic chunk size
+        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
+        subsampling = self.embed.subsampling_rate
+        context = self.embed.right_context + 1  # Add current frame
+        stride = subsampling * decoding_chunk_size
+        decoding_window = (decoding_chunk_size - 1) * subsampling + context
+        num_frames = xs.size(1)
+        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
+        outputs = []
+        offset = 0
+        required_cache_size = decoding_chunk_size * num_decoding_left_chunks
+
+        # Feed overlapping input windows forward step by step
+        for cur in range(0, num_frames - context + 1, stride):
+            end = min(cur + decoding_window, num_frames)
+            chunk_xs = xs[:, cur:end, :]
+            (y, att_cache, cnn_cache) = self.forward_chunk(
+                chunk_xs, offset, required_cache_size, att_cache, cnn_cache)
+            outputs.append(y)
+            offset += y.size(1)
+        ys = torch.cat(outputs, 1)
+        masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool)
+        return ys, masks
diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py
index dc95ef7fe..5d848ec63 100644
--- a/wenet/squeezeformer/encoder_layer.py
+++ b/wenet/squeezeformer/encoder_layer.py
@@ -11,375 +11,112 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Modified from Squeezeformer(https://github.com/kssteven418/Squeezeformer)
-#               Squeezeformer(https://github.com/upskyy/Squeezeformer)
-#               NeMo(https://github.com/NVIDIA/NeMo)
+
+"""SqueezeformerEncoderLayer definition."""
 import torch
 import torch.nn as nn
-from typing import Tuple
-from wenet.squeezeformer.subsampling \
-    import DepthwiseConv2dSubsampling4, TimeReductionLayer1D, TimeReductionLayer2D
-from wenet.squeezeformer.encoder_layer import SqueezeformerEncoderLayer
-from wenet.transformer.embedding import RelPositionalEncoding
-from wenet.transformer.attention import MultiHeadedAttention
-from wenet.squeezeformer.attention import RelPositionMultiHeadedAttention
-from wenet.squeezeformer.positionwise_feed_forward \
-    import PositionwiseFeedForward
-from wenet.squeezeformer.convolution import ConvolutionModule
-from wenet.utils.mask import make_pad_mask, add_optional_chunk_mask
-from wenet.utils.common import get_activation
+from typing import Optional, Tuple


-class SqueezeformerEncoder(nn.Module):
+class SqueezeformerEncoderLayer(nn.Module):
+    """Encoder layer module.
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+            instance can be used as the argument.
+        feed_forward1 (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward` instance can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+        feed_forward2 (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool):
+            True: use layer_norm before each sub-block.
+            False: use layer_norm after each sub-block.
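The four sub-modules listed here compose in a fixed order. A condensed sketch of the post-norm path (`normalize_before=False`) that this layer's forward implements, with Identity stand-ins for the real attention/FFN/conv modules (stand-ins are ours, for shape checking only):

```python
import torch
import torch.nn as nn


def squeezeformer_block(x, attn, ffn1, conv, ffn2, norms, drop):
    """Sub-block order: MHSA -> FFN -> conv -> FFN, each with a residual."""
    ln1, ln2, ln3, ln4 = norms
    x = ln1(x + drop(attn(x)))   # self-attention module + residual
    x = ln2(x + drop(ffn1(x)))   # first feed-forward module + residual
    x = ln3(x + drop(conv(x)))   # convolution module + residual
    x = ln4(x + drop(ffn2(x)))   # second feed-forward module + residual
    return x


size = 8
norms = [nn.LayerNorm(size) for _ in range(4)]
same = nn.Identity()
x = torch.randn(2, 10, size)
y = squeezeformer_block(x, same, same, same, same, norms, nn.Dropout(0.1))
print(y.shape)  # torch.Size([2, 10, 8])
```

Unlike the Conformer's macaron layout (FFN halves sandwiching attention and conv), Squeezeformer interleaves attention/FFN and conv/FFN pairs, which is exactly the order visible in the forward pass later in this patch.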
+ """ + def __init__( self, - input_size: int = 80, - encoder_dim: int = 256, - output_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 12, - reduce_idx: int = 5, - recover_idx: int = 11, - feed_forward_expansion_factor: int = 4, - dw_stride: bool = False, - input_dropout_rate: float = 0.1, - pos_enc_layer_type: str = "rel_pos", - time_reduction_layer_type: str = "conv1d", - do_rel_shift: bool = True, - feed_forward_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - cnn_module_kernel: int = 31, - cnn_norm_type: str = "batch_norm", - dropout: float = 0.1, - causal: bool = False, - adaptive_scale: bool = True, - activation_type: str = "swish", - init_weights: bool = True, - global_cmvn: torch.nn.Module = None, + size: int, + self_attn: torch.nn.Module, + feed_forward1: Optional[nn.Module] = None, + conv_module: Optional[nn.Module] = None, + feed_forward2: Optional[nn.Module] = None, normalize_before: bool = False, - use_dynamic_chunk: bool = False, + dropout_rate: float = 0.1, concat_after: bool = False, - static_chunk_size: int = 0, - use_dynamic_left_chunk: bool = False ): - """Construct SqueezeformerEncoder - - Args: - input_size to use_dynamic_chunk, see in Transformer BaseEncoder. - encoder_dim (int): The hidden dimension of encoder layer. - output_size (int): The output dimension of final projection layer. - attention_heads (int): Num of attention head in attention module. - num_blocks (int): Num of encoder layers. - reduce_idx (int): reduce layer index, from 40ms to 80ms per frame. - recover_idx (int): recover layer index, from 80ms to 40ms per frame. - feed_forward_expansion_factor (int): Enlarge coefficient of FFN. - dw_stride (bool): Whether do depthwise convolution - on subsampling module. - input_dropout_rate (float): Dropout rate of input projection layer. - pos_enc_layer_type (str): Self attention type. - do_rel_shift (bool): Whether to do relative shift - operation on rel-attention module. - cnn_module_kernel (int): Kernel size of CNN module. - activation_type (str): Encoder activation function type. - use_cnn_module (bool): Whether to use convolution module. - cnn_module_kernel (int): Kernel size of convolution module. - adaptive_scale (bool): Whether to use adaptive scale. - init_weights (bool): Whether to initialize weights. - causal (bool): whether to use causal convolution or not. 
- """ - super(SqueezeformerEncoder, self).__init__() - self.global_cmvn = global_cmvn - self.reduce_idx = reduce_idx - self.recover_idx = recover_idx - self._output_size = output_size - self.static_chunk_size = static_chunk_size - self.use_dynamic_chunk = use_dynamic_chunk - self.use_dynamic_left_chunk = use_dynamic_left_chunk - activation = get_activation(activation_type) - - # self-attention module definition - if pos_enc_layer_type != "rel_pos": - encoder_selfattn_layer = MultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate, - ) - else: - encoder_selfattn_layer = RelPositionMultiHeadedAttention - encoder_selfattn_layer_args = ( - attention_heads, - encoder_dim, - attention_dropout_rate, - do_rel_shift, - adaptive_scale, - init_weights - ) - - # feed-forward module definition - positionwise_layer = PositionwiseFeedForward - positionwise_layer_args = ( - encoder_dim, - encoder_dim * feed_forward_expansion_factor, - feed_forward_dropout_rate, - activation, - adaptive_scale, - init_weights - ) - - # convolution module definition - convolution_layer = ConvolutionModule - convolution_layer_args = ( - encoder_dim, cnn_module_kernel, activation, - cnn_norm_type, causal, True, adaptive_scale, init_weights) - - self.embed = DepthwiseConv2dSubsampling4( - 1, encoder_dim, - RelPositionalEncoding(encoder_dim, dropout_rate=0.1), - dw_stride - ) - self.input_proj = nn.Sequential( - nn.Linear( - encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim), - nn.Dropout(p=input_dropout_rate), - ) - if init_weights: - linear_max = (encoder_dim * input_size / 4) ** -0.5 - torch.nn.init.uniform_( - self.input_proj.state_dict()['0.weight'], -linear_max, linear_max) - torch.nn.init.uniform_( - self.input_proj.state_dict()['0.bias'], -linear_max, linear_max) - self.preln = nn.LayerNorm(encoder_dim) - self.encoders = torch.nn.ModuleList([SqueezeformerEncoderLayer( - encoder_dim, - encoder_selfattn_layer(*encoder_selfattn_layer_args), - positionwise_layer(*positionwise_layer_args), - convolution_layer(*convolution_layer_args), - positionwise_layer(*positionwise_layer_args), - normalize_before, - dropout, - concat_after) for _ in range(num_blocks) - ]) - if time_reduction_layer_type == 'conv1d': - time_reduction_layer = TimeReductionLayer1D - time_reduction_layer_args = { - 'channel': encoder_dim, - 'out_dim': encoder_dim, - } + super(SqueezeformerEncoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.layer_norm1 = nn.LayerNorm(size) + self.ffn1 = feed_forward1 + self.layer_norm2 = nn.LayerNorm(size) + self.conv_module = conv_module + self.layer_norm3 = nn.LayerNorm(size) + self.ffn2 = feed_forward2 + self.layer_norm4 = nn.LayerNorm(size) + self.normalize_before = normalize_before + self.dropout = nn.Dropout(dropout_rate) + self.concat_after = concat_after + if concat_after: + self.concat_linear = nn.Linear(size + size, size) else: - time_reduction_layer = TimeReductionLayer2D - time_reduction_layer_args = {'encoder_dim': encoder_dim} - - self.time_reduction_layer = time_reduction_layer(**time_reduction_layer_args) - self.time_recover_layer = nn.Linear(encoder_dim, encoder_dim) - self.final_proj = None - if output_size != encoder_dim: - self.final_proj = nn.Linear(encoder_dim, output_size) - - def output_size(self) -> int: - return self._output_size + self.concat_linear = nn.Identity() def forward( self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - 
) -> Tuple[torch.Tensor, torch.Tensor]: - T = xs.size(1) - masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks) - xs_lens = chunk_masks.squeeze(1).sum(1) - xs = self.input_proj(xs) - xs = self.preln(xs) - recover_tensor = torch.tensor(0.) - recover_chunk_masks = torch.tensor(0.) - recover_pos_emb = torch.tensor(0.) - recover_mask_pad = torch.tensor(0.) - for idx, layer in enumerate(self.encoders): - if idx == self.reduce_idx: - recover_tensor = xs - recover_chunk_masks = chunk_masks - recover_pos_emb = pos_emb - recover_mask_pad = mask_pad - xs, xs_lens, chunk_masks, mask_pad = \ - self.time_reduction_layer(xs, xs_lens, chunk_masks, mask_pad) - pos_emb = pos_emb[:, :xs.size(1), :] - - if idx == self.recover_idx: - # recover output length for ctc decode - xs = torch.repeat_interleave(xs, repeats=2, dim=1) - xs = self.time_recover_layer(xs) - recover_t = recover_tensor.size(1) - xs = recover_tensor + xs[:, :recover_t, :].contiguous() - chunk_masks = recover_chunk_masks - pos_emb = recover_pos_emb - mask_pad = recover_mask_pad - - if self.reduce_idx <= idx < self.recover_idx: - residual = xs - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - xs += residual - else: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - - if self.final_proj is not None: - xs = self.final_proj(xs) - return xs, masks - - def forward_chunk( - self, - xs: torch.Tensor, - offset: int, - required_cache_size: int, - att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), - att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ Forward just one chunk - - Args: - xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim), - where `time == (chunk_size - 1) * subsample_rate + \ - subsample.right_context + 1` - offset (int): current offset in encoder output time stamp - required_cache_size (int): cache size required for next chunk - compuation - >=0: actual cache size - <0: means all history cache is required - att_cache (torch.Tensor): cache tensor for KEY & VALUE in - transformer/conformer attention, with shape - (elayers, head, cache_t1, d_k * 2), where - `head * d_k == hidden-dim` and - `cache_t1 == chunk_size * num_decoding_left_chunks`. - cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer, - (elayers, b=1, hidden-dim, cache_t2), where - `cache_t2 == cnn.lorder - 1` - - Returns: - torch.Tensor: output of current input xs, - with shape (b=1, chunk_size, hidden-dim). - torch.Tensor: new attention cache required for next chunk, with - dynamic shape (elayers, head, ?, d_k * 2) - depending on required_cache_size. - torch.Tensor: new conformer cnn cache required for next chunk, with - same shape as the original cnn_cache. 
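The recover branch shown earlier in this hunk upsamples by frame repetition, trims to the pre-reduction length, and adds back the saved residual. A minimal sketch of just that tensor manipulation (the `time_recover_layer` linear projection is omitted, and the demo shapes are arbitrary):

```python
import torch

recover_tensor = torch.randn(1, 101, 4)  # activations saved before reduction
xs = torch.randn(1, 51, 4)               # half-rate output, roughly T / 2 frames

xs = torch.repeat_interleave(xs, repeats=2, dim=1)  # (1, 102, 4)
recover_t = recover_tensor.size(1)
xs = recover_tensor + xs[:, :recover_t, :]          # trim to 101, add residual
print(xs.shape)  # torch.Size([1, 101, 4])
```

The trim matters because the reduced length rounds up under the conv padding rule, so doubling it can overshoot the original length by a frame.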
- - """ - assert xs.size(0) == 1 - # tmp_masks is just for interface compatibility - tmp_masks = torch.ones(1, - xs.size(1), - device=xs.device, - dtype=torch.bool) - tmp_masks = tmp_masks.unsqueeze(1) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - # NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim) - xs, pos_emb, _ = self.embed(xs, tmp_masks, offset) - # NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim) - elayers, cache_t1 = att_cache.size(0), att_cache.size(2) - chunk_size = xs.size(1) - attention_key_size = cache_t1 + chunk_size - pos_emb = self.embed.position_encoding( - offset=offset - cache_t1, size=attention_key_size) - if required_cache_size < 0: - next_cache_start = 0 - elif required_cache_size == 0: - next_cache_start = attention_key_size - else: - next_cache_start = max(attention_key_size - required_cache_size, 0) - r_att_cache = [] - r_cnn_cache = [] - for i, layer in enumerate(self.encoders): - # NOTE(xcsong): Before layer.forward - # shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2), - # shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2) - xs, _, new_att_cache, new_cnn_cache = layer( - xs, att_mask, pos_emb, - att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache, - cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache - ) - # NOTE(xcsong): After layer.forward - # shape(new_att_cache) is (1, head, attention_key_size, d_k * 2), - # shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2) - r_att_cache.append(new_att_cache[:, :, next_cache_start:, :]) - r_cnn_cache.append(new_cnn_cache.unsqueeze(0)) + x: torch.Tensor, + mask: torch.Tensor, + pos_emb: torch.Tensor, + mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), + att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # self attention module + residual = x if self.normalize_before: - xs = self.after_norm(xs) - - # NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2), - # ? may be larger than cache_t1, it depends on required_cache_size - r_att_cache = torch.cat(r_att_cache, dim=0) - # NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2) - r_cnn_cache = torch.cat(r_cnn_cache, dim=0) - - return (xs, r_att_cache, r_cnn_cache) - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - decoding_chunk_size: int, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ Forward input chunk by chunk with chunk_size like a streaming - fashion - - Here we should pay special attention to computation cache in the - streaming style forward chunk by chunk. Three things should be taken - into account for computation in the current network: - 1. transformer/conformer encoder layers output cache - 2. convolution in conformer - 3. convolution in subsampling - - However, we don't implement subsampling cache for: - 1. We can control subsampling module to output the right result by - overlapping input instead of cache left context, even though it - wastes some computation, but subsampling only takes a very - small fraction of computation in the whole model. - 2. Typically, there are several covolution layers with subsampling - in subsampling module, it is tricky and complicated to do cache - with different convolution layers with different subsampling - rate. - 3. 
Currently, nn.Sequential is used to stack all the convolution - layers in subsampling, we need to rewrite it to make it work - with cache, which is not prefered. - Args: - xs (torch.Tensor): (1, max_len, dim) - chunk_size (int): decoding chunk size - """ - assert decoding_chunk_size > 0 - # The model is trained by static or dynamic chunk - assert self.static_chunk_size > 0 or self.use_dynamic_chunk - subsampling = self.embed.subsampling_rate - context = self.embed.right_context + 1 # Add current frame - stride = subsampling * decoding_chunk_size - decoding_window = (decoding_chunk_size - 1) * subsampling + context - num_frames = xs.size(1) - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device) - outputs = [] - offset = 0 - required_cache_size = decoding_chunk_size * num_decoding_left_chunks + x = self.layer_norm1(x) + x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, att_cache) + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.layer_norm1(x) - # Feed forward overlap input step by step - for cur in range(0, num_frames - context + 1, stride): - end = min(cur + decoding_window, num_frames) - chunk_xs = xs[:, cur:end, :] - (y, att_cache, cnn_cache) = self.forward_chunk( - chunk_xs, offset, required_cache_size, att_cache, cnn_cache) - outputs.append(y) - offset += y.size(1) - ys = torch.cat(outputs, 1) - masks = torch.ones((1, 1, ys.size(1)), device=ys.device, dtype=torch.bool) - return ys, masks + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm2(x) + x = self.ffn1(x) + # we do not use dropout here since it is inside feed forward function + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm2(x) + + # conv module + new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) + residual = x + if self.normalize_before: + x = self.layer_norm3(x) + x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm3(x) + + # ffn module + residual = x + if self.normalize_before: + x = self.layer_norm4(x) + x = self.ffn2(x) + # we do not use dropout here since it is inside feed forward function + x = residual + self.dropout(x) + if not self.normalize_before: + x = self.layer_norm4(x) + + return x, mask, new_att_cache, new_cnn_cache From d77730592d2744bb49c07e8aa3b6308f40d7aef6 Mon Sep 17 00:00:00 2001 From: yygle Date: Sun, 9 Oct 2022 22:54:24 +0800 Subject: [PATCH 26/29] fix formatting issue --- wenet/squeezeformer/encoder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/wenet/squeezeformer/encoder.py b/wenet/squeezeformer/encoder.py index ad4ceec47..6c7b3cd91 100644 --- a/wenet/squeezeformer/encoder.py +++ b/wenet/squeezeformer/encoder.py @@ -42,6 +42,7 @@ def __init__( reduce_idx: int = 5, recover_idx: int = 11, feed_forward_expansion_factor: int = 4, + dw_stride: bool = False, input_dropout_rate: float = 0.1, pos_enc_layer_type: str = "rel_pos", time_reduction_layer_type: str = "conv1d", @@ -73,6 +74,8 @@ def __init__( reduce_idx (int): reduce layer index, from 40ms to 80ms per frame. recover_idx (int): recover layer index, from 80ms to 40ms per frame. feed_forward_expansion_factor (int): Enlarge coefficient of FFN. 
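For context on the `dw_stride` flag this patch threads through the encoder: per the earlier subsampling.py hunk, it switches the second stride-2 convolution of `DepthwiseConv2dSubsampling4` from a full convolution to a depthwise one (`groups == odim`). A parameter-count comparison, with the channel count assumed for illustration:

```python
import torch.nn as nn

odim = 256  # assumed encoder dimension, for illustration
full_conv = nn.Conv2d(odim, odim, kernel_size=3, stride=2)             # dw_stride=False
dw_conv = nn.Conv2d(odim, odim, kernel_size=3, stride=2, groups=odim)  # dw_stride=True


def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


print(n_params(full_conv))  # 590080 (odim * odim * 3 * 3 + odim)
print(n_params(dw_conv))    # 2560   (odim * 1 * 3 * 3 + odim)
```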
+ dw_stride (bool): Whether do depthwise convolution + on subsampling module. input_dropout_rate (float): Dropout rate of input projection layer. pos_enc_layer_type (str): Self attention type. time_reduction_layer_type (str): Conv1d or Conv2d reduction layer. @@ -134,7 +137,8 @@ def __init__( self.embed = DepthwiseConv2dSubsampling4( 1, encoder_dim, - RelPositionalEncoding(encoder_dim, dropout_rate=0.1) + RelPositionalEncoding(encoder_dim, dropout_rate=0.1), + dw_stride ) self.input_proj = nn.Sequential( nn.Linear( From cd82d89b645b12e3530c40593d5ac2b4fe957e88 Mon Sep 17 00:00:00 2001 From: yygle Date: Wed, 12 Oct 2022 19:06:51 +0800 Subject: [PATCH 27/29] [update] SqueezeFormer Large Results --- examples/librispeech/s0/README.md | 126 ++++++++++++ .../conf/train_squeezeformer.yaml | 0 .../train_squeezeformer_bidecoder_large.yaml | 96 ++++++++++ examples/librispeech/squeezeformer/README.md | 179 ------------------ examples/librispeech/squeezeformer/local | 1 - examples/librispeech/squeezeformer/tools | 1 - examples/librispeech/squeezeformer/wenet | 1 - wenet/squeezeformer/encoder_layer.py | 1 - 8 files changed, 222 insertions(+), 183 deletions(-) rename examples/librispeech/{squeezeformer => s0}/conf/train_squeezeformer.yaml (100%) create mode 100644 examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml delete mode 100644 examples/librispeech/squeezeformer/README.md delete mode 120000 examples/librispeech/squeezeformer/local delete mode 120000 examples/librispeech/squeezeformer/tools delete mode 120000 examples/librispeech/squeezeformer/wenet diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 95b697263..c227b90c8 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -2,6 +2,7 @@ ## Conformer Result Bidecoder (large) +* Encoder FLOPs(30s): 96,238,430,720, params: 85,709,704 * Feature info: using fbank feature, cmvn, dither, online speed perturb * Training info: train_conformer_bidecoder_large.yaml, kernel size 31, lr 0.002, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 * Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 @@ -18,8 +19,31 @@ | LM-tglarge + attention rescoring | 2.68 | 6.10 | | LM-fglarge + attention rescoring | 2.65 | 5.98 | +## SqueezeFormer Result (U2++, FFN:2048) + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d, batch_norm, syncbn + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * Encoder FLOPs(30s): 82,283,704,832, params: 85,984,648 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb, spec_aug +* Training info: + * train_squeezeformer_bidecoder_large.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 + * adamw, lr 8e-4, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 2.62 | 6.80 | 2.92 | 6.77 | +| ctc prefix beam search | 2.60 | 6.79 | 2.90 | 6.79 | +| attention decoder | 3.06 | 6.90 | 3.38 | 6.82 | +| attention rescoring | 2.33 | 6.29 | 2.57 | 6.22 | + ## Conformer Result +* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 * Feature info: using fbank feature, cmvn, dither, online speed perturb * Training info: train_conformer.yaml, kernel size 31, lr 0.004, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 
* Decoding info: ctc_weight 0.5, average_num 30 @@ -35,6 +59,82 @@ | attention rescoring (beam 50) | 3.12 | 8.55 | | LM-fglarge + attention rescoring | 3.09 | 7.40 | +## Conformer Result (12 layers, FFN:2048) +* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: train_squeezeformer.yaml, kernel size 31, +* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 +* AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.49 | 9.59 | 3.66 | 9.59 | +| ctc prefix beam search | 3.49 | 9.61 | 3.66 | 9.55 | +| attention decoder | 3.52 | 9.04 | 3.85 | 8.97 | +| attention rescoring | 3.10 | 8.91 | 3.29 | 8.81 | + +## SqueezeFormer Result (SM12, FFN:1024) +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*4=1024 + * Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31, + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr=1e-3, noamhold, warmup=0.2, hold=0.3, lr_decay=1.0 +* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | +| ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | +| attention decoder | 3.59 | 8.74 | 3.75 | 8.70 | +| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | + +## SqueezeFormer Result (SM12, FFN:2048) +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 +* Feature info: using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | +|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.34 | 9.01 | 3.47 | 8.85 | +| ctc prefix beam search | 3.33 | 9.02 | 3.46 | 8.81 | +| attention decoder | 3.64 | 8.62 | 3.91 | 8.33 | +| attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 | + +## SqueezeFormer Result (SM12, FFN:1312) +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d, w/o syncbn + * encoder_dim 328, output_size 256, head 4, ffn_dim 328*4=1312 + * encoder FLOPs(30s): 34,103,960,008, params: 35,678,352 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31, + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 + * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +| decoding mode | dev clean | dev other | test clean | test other | 
+|----------------------------------|-----------|-----------|------------|------------| +| ctc greedy search | 3.20 | 8.46 | 3.30 | 8.58 | +| ctc prefix beam search | 3.18 | 8.44 | 3.30 | 8.55 | +| attention decoder | 3.38 | 8.31 | 3.89 | 8.32 | +| attention rescoring | 2.81 | 7.86 | 2.96 | 7.91 | + ## Conformer U2++ Result * Feature info: using fbank feature, cmvn, no speed perturb, dither @@ -43,17 +143,41 @@ * Git hash: 65270043fc8c2476d1ab95e7c39f730017a670e0 test clean + | decoding mode | full | 16 | |--------------------------------|------|------| | ctc prefix beam search | 3.76 | 4.54 | | attention rescoring | 3.32 | 3.80 | test other + | decoding mode | full | 16 | |--------------------------------|-------|-------| | ctc prefix beam search | 9.50 | 11.52 | | attention rescoring | 8.67 | 10.38 | +## SqueezeFormer Result (U2++, FFN:2048) + +* Encoder info: + * SM12, reduce_idx 5, recover_idx 11, conv1d, layer_norm, do_rel_shift false + * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 + * Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 +* Feature info: + * using fbank feature, cmvn, dither, online speed perturb +* Training info: + * train_squeezeformer.yaml, kernel size 31 + * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 + * adamw, lr 1e-3, NoamHold, warmup 0.1, hold 0.4, lr_decay 1.0 +* Decoding info: + * ctc_weight 0.3, reverse weight 0.5, average_num 30 + +test clean + +| decoding mode | full | 16 | +|--------------------------------|------|------| +| ctc prefix beam search | 3.81 | 4.59 | +| attention rescoring | 3.36 | 3.93 | + ## Conformer U2 Result * Feature info: using fbank feature, cmvn, speed perturb, dither @@ -65,6 +189,7 @@ test other * LM-fglarge: [4-gram.arpa.gz](http://www.openslr.org/resources/11/4-gram.arpa.gz) test clean + | decoding mode | full | 16 | |----------------------------------|------|------| | ctc prefix beam search | 4.26 | 5.00 | @@ -76,6 +201,7 @@ test clean | LM-fglarge + attention rescoring | 3.38 | 3.74 | test other + | decoding mode | full | 16 | |----------------------------------|-------|-------| | ctc prefix beam search | 10.87 | 12.87 | diff --git a/examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml b/examples/librispeech/s0/conf/train_squeezeformer.yaml similarity index 100% rename from examples/librispeech/squeezeformer/conf/train_squeezeformer.yaml rename to examples/librispeech/s0/conf/train_squeezeformer.yaml diff --git a/examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml b/examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml new file mode 100644 index 000000000..3bac09e6e --- /dev/null +++ b/examples/librispeech/s0/conf/train_squeezeformer_bidecoder_large.yaml @@ -0,0 +1,96 @@ +# network architecture +# encoder related +encoder: squeezeformer +encoder_conf: + encoder_dim: 512 + output_size: 512 # dimension of attention + attention_heads: 8 + num_blocks: 12 # the number of encoder blocks + reduce_idx: 5 + recover_idx: 11 + feed_forward_expansion_factor: 4 + input_dropout_rate: 0.1 + feed_forward_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + cnn_module_kernel: 31 + cnn_norm_type: batch_norm + adaptive_scale: true + normalize_before: false + +# decoder related +decoder: bitransformer +decoder_conf: + attention_heads: 8 + linear_units: 2048 + num_blocks: 3 + r_num_blocks: 3 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + self_attention_dropout_rate: 0.1 + src_attention_dropout_rate: 0.1 + +# hybrid CTC/attention +model_conf: + ctc_weight: 0.3 + 
lsm_weight: 0.1 # label smoothing option + length_normalized_loss: false + reverse_weight: 0.3 + +# dataset related +dataset_conf: + syncbn: true + filter_conf: + max_length: 2000 + min_length: 50 + token_max_length: 400 + token_min_length: 1 + min_output_input_ratio: 0.0005 + max_output_input_ratio: 0.1 + resample_conf: + resample_rate: 16000 + speed_perturb: true + fbank_conf: + num_mel_bins: 80 + frame_shift: 10 + frame_length: 25 + dither: 1.0 + spec_aug: true + spec_aug_conf: + num_t_mask: 3 + num_f_mask: 2 + max_t: 100 + max_f: 27 + max_w: 80 +# warp_for_time: true + spec_sub: true + spec_sub_conf: + num_t_sub: 3 + max_t: 30 + shuffle: true + shuffle_conf: + shuffle_size: 1500 + sort: true + sort_conf: + sort_size: 500 # sort_size should be less than shuffle_size + batch_conf: + batch_type: 'static' # static or dynamic + batch_size: 12 + +grad_clip: 5 +accum_grad: 4 +max_epoch: 500 +log_interval: 100 + +optim: adamw +optim_conf: + lr: 1.e-3 + weight_decay: 4.e-5 + +scheduler: NoamHoldAnnealing +scheduler_conf: + warmup_ratio: 0.2 + hold_ratio: 0.3 + max_steps: 87960 + decay_rate: 1.0 + min_lr: 1.e-5 + diff --git a/examples/librispeech/squeezeformer/README.md b/examples/librispeech/squeezeformer/README.md deleted file mode 100644 index 322bac803..000000000 --- a/examples/librispeech/squeezeformer/README.md +++ /dev/null @@ -1,179 +0,0 @@ -# Develop Record - -``` -squeezeformer -├── attention.py # reltive multi-head attention module -├── conv2d.py # self defined conv2d valid padding module -├── convolution.py # convolution module in squeezeformer block -├── encoder_layer.py # squeezeformer encoder layer -├── encoder.py # squeezeformer encoder class -├── positionwise_feed_forward.py # feed forward layer -├── subsampling.py # sub-sampling layer, time reduction layer -└── utils.py # residual connection module -``` - -* Implementation Details - * Squeezeformer Encoder - * [x] add pre layer norm before squeezeformer block - * [x] derive time reduction layer from tensorflow version - * [x] enable adaptive scale operation - * [x] enable init weights for deep model training - * [x] enable training config and results - * [x] enable dynamic chunk and JIT export - * Training - * [x] enable NoamHoldAnnealing schedular - -# Performance Record - -### Conformer -* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 -* Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_conformer.yaml, kernel size 31, lr 0.004, -* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 -* Decoding info: ctc_weight 0.5, average_num 30 - -| decoding mode | test clean | test other | -|----------------------------------|------------|------------| -| ctc greedy search | 3.51 | 9.57 | -| ctc prefix beam search | 3.51 | 9.56 | -| attention decoder | 3.05 | 8.36 | -| attention rescoring | 3.18 | 8.72 | - -### Conformer Result (12 layers, FFN:2048) -* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 -* Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_squeezeformer.yaml, kernel size 31, -* batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 -* AdamW, lr 1e-3, NoamHold, warmup 0.2, hold 0.3, lr_decay 1.0 -* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 - -| decoding mode | dev clean | dev other | test clean | test other | -|----------------------------------|-----------|-----------|------------|------------| -| ctc greedy search | 3.49 | 9.59 | 3.66 | 9.59 | -| ctc prefix beam search | 3.49 | 9.61 | 3.66 | 
9.55 | -| attention decoder | 3.52 | 9.04 | 3.85 | 8.97 | -| attention rescoring | 3.10 | 8.91 | 3.29 | 8.81 | - -### Squeezeformer Result (SM12, FFN:1024) -* Encoder info: - * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn - * encoder_dim 256, output_size 256, head 4, ffn_dim 256*4=1024 - * Encoder FLOPs(30s): 21,158,877,440, params: 22,219,912 -* Feature info: - * using fbank feature, cmvn, dither, online speed perturb -* Training info: - * train_squeezeformer.yaml, kernel size 31, - * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 - * adamw, lr=1e-3, noamhold, warmup=0.2, hold=0.3, lr_decay=1.0 -* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 - -| decoding mode | dev clean | dev other | test clean | test other | -|----------------------------------|-----------|-----------|------------|------------| -| ctc greedy search | 3.49 | 9.24 | 3.51 | 9.28 | -| ctc prefix beam search | 3.44 | 9.23 | 3.51 | 9.25 | -| attention decoder | 3.59 | 8.74 | 3.75 | 8.70 | -| attention rescoring | 2.97 | 8.48 | 3.07 | 8.44 | - -### Squeezeformer Result (SM12, FFN:2048) -* Encoder info: - * SM12, reduce_idx 5, recover_idx 11, conv2d, w/o syncbn - * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 - * encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 -* Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: - * train_squeezeformer.yaml, kernel size 31 - * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 - * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 -* Decoding info: - * ctc_weight 0.3, reverse weight 0.5, average_num 30 - -| decoding mode | dev clean | dev other | test clean | test other | -|----------------------------------|-----------|-----------|------------|------------| -| ctc greedy search | 3.34 | 9.01 | 3.47 | 8.85 | -| ctc prefix beam search | 3.33 | 9.02 | 3.46 | 8.81 | -| attention decoder | 3.64 | 8.62 | 3.91 | 8.33 | -| attention rescoring | 2.89 | 8.34 | 3.10 | 8.03 | - -### Squeezeformer Result (SM12, FFN:1312) -* Encoder info: - * SM12, reduce_idx 5, recover_idx 11, conv1d, w/o syncbn - * encoder_dim 328, output_size 256, head 4, ffn_dim 328*4=1312 - * encoder FLOPs(30s): 34,103,960,008, params: 35,678,352 -* Feature info: - * using fbank feature, cmvn, dither, online speed perturb -* Training info: - * train_squeezeformer.yaml, kernel size 31, - * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 - * adamw, lr 1e-3, noamhold, warmup 0.2, hold 0.3, lr_decay 1.0 -* Decoding info: - * ctc_weight 0.3, reverse weight 0.5, average_num 30 - -| decoding mode | dev clean | dev other | test clean | test other | -|----------------------------------|-----------|-----------|------------|------------| -| ctc greedy search | 3.20 | 8.46 | 3.30 | 8.58 | -| ctc prefix beam search | 3.18 | 8.44 | 3.30 | 8.55 | -| attention decoder | 3.38 | 8.31 | 3.89 | 8.32 | -| attention rescoring | 2.81 | 7.86 | 2.96 | 7.91 | - -### Conformer U2++ Result - -* Encoder FLOPs(30s): 34,085,088,512, params: 34,761,608 -* Feature info: using fbank feature, cmvn, no speed perturb, dither -* Training info: train_u2++_conformer.yaml lr 0.001, batch size 24, 8 gpu, acc_grad 1, 120 epochs, dither 1.0 -* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 -* Git hash: 65270043fc8c2476d1ab95e7c39f730017a670e0 - -test clean - -| decoding mode | full | 16 | -|--------------------------------|------|------| -| ctc prefix beam search | 3.76 | 4.54 | -| attention rescoring | 3.32 | 3.80 | - -test other - 
-| decoding mode | full | 16 | -|--------------------------------|-------|-------| -| ctc prefix beam search | 9.50 | 11.52 | -| attention rescoring | 8.67 | 10.38 | - -### Squeezeformer Result (U2++, FFN:2048) - -* Encoder info: - * SM12, reduce_idx 5, recover_idx 11, conv1d, layer_norm, do_rel_shift false - * encoder_dim 256, output_size 256, head 4, ffn_dim 256*8=2048 - * Encoder FLOPs(30s): 28,230,473,984, params: 34,827,400 -* Feature info: - * using fbank feature, cmvn, dither, online speed perturb -* Training info: - * train_squeezeformer.yaml, kernel size 31 - * batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 0.1 - * adamw, lr 1e-3, NoamHold, warmup 0.1, hold 0.4, lr_decay 1.0 -* Decoding info: - * ctc_weight 0.3, reverse weight 0.5, average_num 30 - -test clean - -| decoding mode | full | 16 | -|--------------------------------|------|------| -| ctc prefix beam search | 3.81 | 4.59 | -| attention rescoring | 3.36 | 3.93 | - -test other - -| decoding mode | full | 16 | -|--------------------------------|-------|-------| -| ctc prefix beam search | 9.12 | 11.17 | -| attention rescoring | 8.43 | 10.21 | - -### Conformer Result Bidecoder (large) - -* Encoder FLOPs(30s): 96,238,430,720, params: 85,709,704 -* Feature info: using fbank feature, cmvn, dither, online speed perturb -* Training info: train_conformer_bidecoder_large.yaml, kernel size 31, lr 0.002, batch size 12, 8 gpu, acc_grad 4, 120 epochs, dither 1.0 -* Decoding info: ctc_weight 0.3, reverse weight 0.5, average_num 30 - -| decoding mode | test clean | test other | -|----------------------------------|------------|------------| -| ctc prefix beam search | 2.96 | 7.14 | -| attention rescoring | 2.66 | 6.53 | diff --git a/examples/librispeech/squeezeformer/local b/examples/librispeech/squeezeformer/local deleted file mode 120000 index ea4a20415..000000000 --- a/examples/librispeech/squeezeformer/local +++ /dev/null @@ -1 +0,0 @@ -../s0/local \ No newline at end of file diff --git a/examples/librispeech/squeezeformer/tools b/examples/librispeech/squeezeformer/tools deleted file mode 120000 index c92f4172d..000000000 --- a/examples/librispeech/squeezeformer/tools +++ /dev/null @@ -1 +0,0 @@ -../../../tools \ No newline at end of file diff --git a/examples/librispeech/squeezeformer/wenet b/examples/librispeech/squeezeformer/wenet deleted file mode 120000 index 702de77db..000000000 --- a/examples/librispeech/squeezeformer/wenet +++ /dev/null @@ -1 +0,0 @@ -../../../wenet \ No newline at end of file diff --git a/wenet/squeezeformer/encoder_layer.py b/wenet/squeezeformer/encoder_layer.py index 5d848ec63..3c6bdd44a 100644 --- a/wenet/squeezeformer/encoder_layer.py +++ b/wenet/squeezeformer/encoder_layer.py @@ -94,7 +94,6 @@ def forward( if self.normalize_before: x = self.layer_norm2(x) x = self.ffn1(x) - # we do not use dropout here since it is inside feed forward function x = residual + self.dropout(x) if not self.normalize_before: x = self.layer_norm2(x) From 3c55dde0d7fa8024c79a90c66120a2a118fded88 Mon Sep 17 00:00:00 2001 From: yygle Date: Wed, 12 Oct 2022 19:20:13 +0800 Subject: [PATCH 28/29] fix formatting issues --- examples/librispeech/s0/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index c227b90c8..94465bd68 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -23,7 +23,7 @@ * Encoder info: * SM12, reduce_idx 5, recover_idx 11, conv1d, batch_norm, syncbn - * encoder_dim 256, 
output_size 256, head 4, ffn_dim 256*8=2048 + * encoder_dim 512, output_size 512, head 8, ffn_dim 512*4=2048 * Encoder FLOPs(30s): 82,283,704,832, params: 85,984,648 * Feature info: * using fbank feature, cmvn, dither, online speed perturb, spec_aug From 0824e56177fc70e83b50b415a2a924b3aeba5b6d Mon Sep 17 00:00:00 2001 From: yygle Date: Fri, 14 Oct 2022 11:37:03 +0800 Subject: [PATCH 29/29] fix format issues --- examples/librispeech/s0/README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/librispeech/s0/README.md b/examples/librispeech/s0/README.md index 94465bd68..4892ddb4a 100644 --- a/examples/librispeech/s0/README.md +++ b/examples/librispeech/s0/README.md @@ -178,6 +178,13 @@ test clean | ctc prefix beam search | 3.81 | 4.59 | | attention rescoring | 3.36 | 3.93 | +test other + +| decoding mode | full | 16 | +|--------------------------------|-------|-------| +| ctc prefix beam search | 9.12 | 11.17 | +| attention rescoring | 8.43 | 10.21 | + ## Conformer U2 Result * Feature info: using fbank feature, cmvn, speed perturb, dither