Squeezeformer #1447

Merged (29 commits), Oct 18, 2022
Conversation

@yygle (Contributor) commented Sep 15, 2022

Development Record

squeezeformer
├── attention.py                        # relative multi-head attention module
├── conv2d.py                           # self-defined conv2d valid-padding module
├── convolution.py                      # convolution module in squeezeformer block
├── encoder_layer.py                    # squeezeformer encoder layer
├── encoder.py                          # squeezeformer encoder class 
├── positionwise_feed_forward.py        # feed forward layer 
├── subsampling.py                      # sub-sampling layer, time reduction layer
└── utils.py                            # residual connection module
  • Implementation Details
    • Squeezeformer Encoder
      • add a pre layer norm before each squeezeformer block
      • derive the time reduction layer from the TensorFlow version
      • enable the adaptive scale operation (a minimal sketch follows after this list)
      • enable weight initialization for deep model training
      • add training config and results
      • enable dynamic chunk training and JIT export
    • Training
      • enable the NoamHoldAnnealing scheduler
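For reference, a minimal sketch of the adaptive scale operation mentioned above: a learnable per-channel scale and bias applied to a block's input, as described in the Squeezeformer paper. This is an illustrative PyTorch rewrite, not necessarily the exact module merged in this PR.

import torch
import torch.nn as nn

class AdaptiveScale(nn.Module):
    """Learnable per-channel scale and bias applied to the module input."""
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, dim); scale and bias broadcast over batch and time
        return x * self.scale + self.bias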



class ResidualModule(nn.Module):
"""
Collaborator:
You can move it to encoder.py, utils.py is not a proper name.

Collaborator:
What if we just do it directly in the forward function?

Contributor Author:
I moved the residual connection to encoder.py and removed utils.py.

@@ -0,0 +1,172 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
Collaborator:
Please note that you should change the company and author.

Collaborator:
Or append

Contributor Author:
The copyright has been added.

@@ -0,0 +1 @@
../s0/local
Collaborator:
A separate squeezeformer directory in examples/librispeech is not required here; we can just do it in s0 via config, since it shares the same training and decoding recipe.

Contributor Author:
Okay, how about removing it after the squeezeformer part of the README.md is fully updated?

Collaborator:
ok!


# dataset related
dataset_conf:
  syncbn: true
Member:
The code related to this does not seem to have been submitted? The syncbn conversion can apparently be done in a single call to the torch API in Train.py:

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

What is the reasoning behind putting it in the dataset_conf field?
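For context, the standard conversion plus DDP wrapping looks roughly like the following. This is a generic PyTorch sketch assuming torch.distributed has already been initialized; it is not the wenet training code.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # Replace every BatchNorm layer with SyncBatchNorm so batch statistics
    # are synchronized across processes, then wrap the model with DDP.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda(local_rank)
    return DDP(model, device_ids=[local_rank])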

Contributor Author:
syncbn cannot be enabled directly in wenet mainly because data imbalance makes processes wait for each other. In the complete implementation I considered two cases: 1. DDP without data partitioning (each process iterates over the full dataset), and 2. partitioning the dataset and dropping the leftover part. That is why, in this version of the implementation, I tied this option to the dataset. Since this code is unrelated to the Squeezeformer algorithm itself and falls under engineering optimization, it will be submitted in a separate PR.

Collaborator:
OK, I'll merge it first; you can keep optimizing and iterating.

@robin1001 (Collaborator):
Great job, looking forward to your future work and SOTA result.

@robin1001 merged commit bbf844a into wenet-e2e:main on Oct 18, 2022
)
self.input_proj = nn.Sequential(
    nn.Linear(
        encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim),
Member:
  1. Why not put this into DepthwiseConv2dSubsampling4? Off the top of my head, a more aggressive U-Net-like structure might be worth trying later (e.g. going from 10ms straight to 80ms, getting more aggressive to 160ms after the 5th layer, and coming back to 40ms or 80ms in the last layer). That would require a DepthwiseConv2dSubsampling8, and line 145 would need the subsampling output dimension adjusted accordingly. From that angle, input_proj is best coupled into the subsampling module. :)
  2. forward_chunk does not call input_proj; is the plan to submit the streaming feature in a follow-up PR, with this just a placeholder?

Contributor Author:
  1. Most of this is indeed the same as conv2d4; it was not merged because we had discussed not affecting the existing conformer usage. I also agree that coupling input_proj into the subsampling module is more reasonable, and I will update it later (a rough sketch of this coupling follows after this reply).
  2. This part, including complete streaming inference, is not finished yet and will be updated soon.
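As a rough illustration of the coupling suggested above, here is a subsampling module that owns its output projection, so the encoder never needs to know the flattened feature width. The module and parameter names are hypothetical; this is not the merged implementation.

import torch
import torch.nn as nn

class Conv2dSubsampling4WithProj(nn.Module):
    """4x time subsampling that projects directly to encoder_dim."""
    def __init__(self, input_size: int, encoder_dim: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, encoder_dim, 3, stride=2),
            nn.ReLU(),
            nn.Conv2d(encoder_dim, encoder_dim, 3, stride=2),
            nn.ReLU(),
        )
        # frequency width after two stride-2 valid convolutions
        freq = ((input_size - 1) // 2 - 1) // 2
        self.input_proj = nn.Linear(encoder_dim * freq, encoder_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, input_size)
        x = self.conv(x.unsqueeze(1))          # (batch, encoder_dim, time', freq)
        b, c, t, f = x.size()
        x = x.transpose(1, 2).contiguous().view(b, t, c * f)
        return self.input_proj(x)              # (batch, time', encoder_dim)

With this layout, a more aggressive variant (e.g. 8x subsampling) would only change this module, and the encoder would stay untouched.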

@xingchensong (Member):
Great Job! THX!

if self.reduce_idx <= idx < self.recover_idx:
residual = xs
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
xs += residual
Contributor:
Is the residual here different from the one in the paper? In the paper, it seems that only the reduce layer and the recover layer have a residual connection.

Contributor Author (@yygle), Oct 20, 2022:
This part may need to be double-checked, i.e. whether having the residual only at those two places affects model quality; the different implementations I have seen differ slightly on this point.
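For comparison, a rough sketch of the paper-style variant being discussed, in which the features saved right before time reduction are added back only once, at the recover layer. Names like time_reduction_layer, time_recover_layer, reduce_idx and recover_idx are illustrative placeholders, not the merged code.

# inside the encoder's forward loop (illustrative pseudocode)
for idx, layer in enumerate(self.encoders):
    if idx == self.reduce_idx:
        recover_residual = xs        # keep the full-rate features
        xs, chunk_masks = self.time_reduction_layer(xs, chunk_masks)
    xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
    if idx == self.recover_idx:
        xs = self.time_recover_layer(xs)
        # single skip connection from the reduce point to the recover point
        xs = xs + recover_residual[:, :xs.size(1), :]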

xs = self.final_proj(xs)
return xs, masks

def forward_chunk(
Contributor:
For the downsampled layers between reduce and recover, the chunk is halved during streaming; is the attention cache also halved (e.g. chunk_size = 16, num_left_chunks = 4)?

Contributor Author:
Very good suggestion! Yes, it needs to be halved there. The streaming feature is not complete yet; I will update it soon.

Contributor:
Since the chunk is halved during streaming, does training also need to restrict the chunk size to be even? Also, I keep the cache of the layers between reduce and recover separate from the other layers, which gives three caches: atten_cache, reduce_atten_cache and conv_cache. Looking forward to a more elegant approach.

Contributor Author:
Restricting the chunk to an even size during training does look more reasonable from the standpoint of training/inference consistency, but in my experiments so far the impact seems small.
The cache does currently need three tensors. I am also thinking about this part; better solutions are welcome.

Member:
Using external padding plus internal slicing, it should be possible to merge the caches of the layers running at different frame rates into one.
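One possible reading of that suggestion, sketched very roughly: keep a single cache tensor allocated at the full-rate size for every layer (external padding), and let each layer slice out only the frames it actually needs (internal slicing). Everything below is a hypothetical illustration, not code from this PR.

import torch

def slice_layer_cache(att_cache: torch.Tensor,
                      layer_idx: int,
                      required_cache_size: int,
                      reduced: bool) -> torch.Tensor:
    """att_cache: (num_layers, head, full_cache_size, d_k * 2), left-padded with zeros.

    Layers between reduce and recover run at half the frame rate, so they
    only read the most recent half of the full-rate cache window.
    """
    size = required_cache_size // 2 if reduced else required_cache_size
    layer_cache = att_cache[layer_idx]
    # take the last `size` frames; the zero padding on the left is ignored
    return layer_cache[:, layer_cache.size(1) - size:, :]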
