
Commit

[Efficient Conformer] remove concat after to simplify the code flow (#1764)

* add Efficient Conformer implementation

* fix trailing whitespace, formatting, and semantics

* Ensures consistency of the forward_chunk interface and deletes all runtime changes. Completed the causal and non-causal convolution model tests for the EfficientConformer, as well as JIT runtime tests. Modified yaml files for Aishell-1

* [EfficientConformer] add Aishell-1 Results

* [EfficientConformer] support ONNX GPU export, add LibriSpeech results, and fix a bug in V2 streaming decoding

* [Efficient Conformer] add model params to README.

* fix trailing whitespace

* [Efficient Conformer] remove concat after to simplify the code flow
zwglory authored Mar 20, 2023
1 parent d3fd9b9 commit f5b9b3e
Showing 2 changed files with 2 additions and 15 deletions.
3 changes: 0 additions & 3 deletions wenet/efficient_conformer/encoder.py
@@ -61,7 +61,6 @@ def __init__(
 input_layer: str = "conv2d",
 pos_enc_layer_type: str = "rel_pos",
 normalize_before: bool = True,
-concat_after: bool = False,
 static_chunk_size: int = 0,
 use_dynamic_chunk: bool = False,
 global_cmvn: torch.nn.Module = None,
@@ -221,7 +220,6 @@ def __init__(
 count_include_pad=False), # pointwise_conv_layer
 dropout_rate,
 normalize_before,
-concat_after,
 ))
 index = index + 1
 else:
@@ -239,7 +237,6 @@ def __init__(
 *convolution_layer_args_normal) if use_cnn_module else None,
 dropout_rate,
 normalize_before,
-concat_after,
 ))

 self.encoders = torch.nn.ModuleList(layers)
14 changes: 2 additions & 12 deletions wenet/efficient_conformer/encoder_layer.py
@@ -40,10 +40,6 @@ class StrideConformerEncoderLayer(nn.Module):
 normalize_before (bool):
 True: use layer_norm before each sub-block.
 False: use layer_norm after each sub-block.
-concat_after (bool): Whether to concat attention layer's input and
-output.
-True: x -> x + linear(concat(x, att(x)))
-False: x -> x + att(x)
 """
 def __init__(
 self,
@@ -54,8 +50,7 @@ def __init__(
 conv_module: Optional[nn.Module] = None,
 pointwise_conv_layer: Optional[nn.Module] = None,
 dropout_rate: float = 0.1,
-normalize_before: bool = True,
-concat_after: bool = False,
+normalize_before: bool = True
 ):
 """Construct an EncoderLayer object."""
 super().__init__()
@@ -79,7 +74,6 @@ def __init__(
 self.dropout = nn.Dropout(dropout_rate)
 self.size = size
 self.normalize_before = normalize_before
-self.concat_after = concat_after
 self.concat_linear = nn.Linear(size + size, size)

 def forward(
@@ -131,11 +125,7 @@ def forward(
 x_att, new_att_cache = self.self_attn(
 x, x, x, mask, pos_emb, att_cache)

-if self.concat_after:
-x_concat = torch.cat((x, x_att), dim=-1)
-x = residual + self.concat_linear(x_concat)
-else:
-x = residual + self.dropout(x_att)
+x = residual + self.dropout(x_att)
 if not self.normalize_before:
 x = self.norm_mha(x)
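For context, the removed concat_after option switched the residual connection around self-attention from a plain residual to a learned projection of the concatenated layer input and attention output. Below is a minimal standalone sketch (not from the repository; tensor shapes and the names size, x, x_att are illustrative stand-ins) contrasting the removed branch with the path that remains after this commit:

import torch
import torch.nn as nn

size = 8
x = torch.randn(2, 5, size)        # (batch, time, feature) input to the layer
x_att = torch.randn(2, 5, size)    # stand-in for the self-attention output att(x)
residual = x
dropout = nn.Dropout(0.1)
concat_linear = nn.Linear(size + size, size)

# Removed branch (concat_after=True): x -> x + linear(concat(x, att(x)))
x_concat = torch.cat((x, x_att), dim=-1)
out_concat = residual + concat_linear(x_concat)

# Remaining path (the only one after this commit): x -> x + dropout(att(x))
out_simple = residual + dropout(x_att)

print(out_concat.shape, out_simple.shape)  # both torch.Size([2, 5, 8])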
