Squeezeformer #1447
Conversation
wenet/squeezeformer/utils.py
Outdated
class ResidualModule(nn.Module):
    """
You can move it to encoder.py; utils.py is not a proper name.
What if we just do it directly in the forward function?
I moved the residual connection to encoder.py and removed utils.py.
@@ -0,0 +1,172 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
Please note you should change the company and author.
Or append your own.
The copyright is added.
@@ -0,0 +1 @@
../s0/local
A separate squeezeformer directory in examples/librispeech is not required here; we can just do it in s0 by config, since it shares the same training and decoding recipe.
Okay, how about removing it after the squeezeformer section of README.md is updated?
ok!
# dataset related
dataset_conf:
    syncbn: true
The code related to this doesn't seem to have been pushed yet? The syncbn conversion can apparently be done in one step in train.py by calling the torch API:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
What is the reasoning behind putting it under the dataset_conf section?
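For reference, a minimal sketch of the one-step conversion mentioned above (`model` here is a placeholder module, not wenet's actual model):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(80, 256), nn.BatchNorm1d(256))  # placeholder model

# swap every BatchNorm layer for SyncBatchNorm; this must happen before the
# model is wrapped with DistributedDataParallel
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# model = nn.parallel.DistributedDataParallel(model)  # inside an initialized process group
```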
syncbn can't be used directly in wenet mainly because unbalanced data makes processes wait on each other. In the full implementation I considered two cases: 1. DDP without data splitting (each process iterates the complete dataset), and 2. splitting the dataset and dropping the surplus part. That is why, in this version of my implementation, the variable is bound to the dataset. Since this code is unrelated to the Squeezeformer algorithm itself and falls under engineering optimization, it will be submitted in a separate PR.
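A sketch of the two cases described (illustrative only; wenet's actual dataset pipeline differs):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.randn(1000, 80))  # placeholder dataset

# case 1: no splitting -- every DDP process iterates the full dataset, so
# all processes always take the same number of steps
full_loader = DataLoader(dataset, batch_size=16, shuffle=True)

# case 2: split across processes and drop the uneven remainder, so no
# process stalls inside a SyncBatchNorm all-reduce waiting for the others
# (num_replicas/rank given explicitly here; normally taken from the process group)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0,
                             shuffle=True, drop_last=True)
split_loader = DataLoader(dataset, batch_size=16, sampler=sampler)
```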
OK, then I'll merge it first; you can keep optimizing and iterating.
Great job, looking forward to your future work and SOTA results.
)
self.input_proj = nn.Sequential(
    nn.Linear(
        encoder_dim * (((input_size - 1) // 2 - 1) // 2), encoder_dim),
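For reference, the flattened width this Linear consumes comes from two stride-2 convolutions over the feature axis; with the usual 80-dim fbank input (the encoder_dim value is illustrative):

```python
input_size = 80    # fbank feature dimension
encoder_dim = 256  # illustrative value

# two stride-2 convolutions shrink the feature axis: 80 -> 39 -> 19
freq = ((input_size - 1) // 2 - 1) // 2
assert freq == 19
in_features = encoder_dim * freq  # 256 * 19 = 4864, projected back to encoder_dim
```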
- Why isn't this placed in DepthwiseConv2dSubsampling4? Off the top of my head, a more aggressive U-Net structure may come later (e.g. from 10 ms straight to 80 ms, then aggressively to 160 ms after the 5th layer, and back to 40 ms or 80 ms in the last layer). That would require a DepthwiseConv2dSubsampling8, and line 145 would need a matching change to the subsampling output dimension. From that angle, input_proj is best coupled into the subsampling module; see the sketch below. :)
- forward_chunk doesn't call input_proj; is streaming support planned for a later PR, with this just a placeholder?
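A minimal sketch of the coupling suggested in the first point (a hypothetical module, not this PR's code; the conv stack and shapes are assumptions):

```python
import torch
from torch import nn

class DepthwiseConv2dSubsampling4(nn.Module):
    """Sketch: the module owns both the conv stack and the output
    projection, so callers never need to know the flattened width."""

    def __init__(self, input_size: int, encoder_dim: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, encoder_dim, 3, stride=2),
            nn.ReLU(),
            # depthwise second conv, as the module name suggests
            nn.Conv2d(encoder_dim, encoder_dim, 3, stride=2,
                      groups=encoder_dim),
            nn.ReLU(),
        )
        # two stride-2 convs halve the feature axis twice
        freq = ((input_size - 1) // 2 - 1) // 2
        self.input_proj = nn.Linear(encoder_dim * freq, encoder_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, time, input_size)
        x = self.conv(x.unsqueeze(1))            # (B, C, T', F')
        b, c, t, f = x.size()
        x = x.transpose(1, 2).reshape(b, t, c * f)
        return self.input_proj(x)                # (B, T', encoder_dim)
```

With this layout, a future DepthwiseConv2dSubsampling8 would only need to change its own conv stack and freq computation, leaving the encoder untouched.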
1. Most of this functionality is indeed the same as conv2d4; it wasn't merged because we had discussed not affecting conformer usage. As for coupling input_proj into the subsampling module, I also think that way is more reasonable and will update it later.
2. This part, including complete streaming inference, is not finished yet; I will update it soon.
Great job! THX!
if self.reduce_idx <= idx < self.recover_idx:
    residual = xs
    xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
    xs += residual
Is the residual here different from the paper? In the paper, it seems only the reduce layer and the recover layer have a residual connection.
This part may need to be verified again: whether having the residual only at those two places affects the network's performance. The different implementations I have seen differ slightly here.
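For comparison, a sketch of the paper-style variant as a drop-in replacement for the loop above (it assumes xs is back at the full frame rate once the recover layer has run, so the skip shapes match):

```python
# one U-Net-style skip across the whole reduced block, instead of a
# per-layer skip inside it
for idx, layer in enumerate(self.encoders):
    if idx == self.reduce_idx:
        residual = xs                    # save full-rate features once
    xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
    if idx == self.recover_idx:
        xs = xs + residual               # add them back once
```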
xs = self.final_proj(xs)
return xs, masks

def forward_chunk(
For the downsampled layers between reduce and recover, when the chunk is halved during streaming, is the attention cache also halved (e.g. chunk_size = 16, num_left_chunks = 4)?
Very good suggestion! Yes, it needs to be halved here. The streaming support is not complete yet; I will push an update soon.
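Concretely, for the example given (illustrative arithmetic only):

```python
chunk_size = 16
num_left_chunks = 4

# full-rate layers cache this many past frames
full_cache = chunk_size * num_left_chunks             # 64

# between reduce_idx and recover_idx the frame rate is halved, so both
# the chunk and the cached attention window shrink by the same factor
reduced_chunk_size = chunk_size // 2                  # 8
reduced_cache = reduced_chunk_size * num_left_chunks  # 32
```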
Since the chunk is halved during streaming, should training also restrict the chunk size to be even? Also, I keep the cache for the layers between reduce and recover separate from the other layers, which effectively gives three caches: atten_cache, reduce_atten_cache, and conv_cache. A more elegant approach would be welcome.
An even chunk size during training is indeed more reasonable from the standpoint of training/inference consistency, but in my experiments so far the impact seems small.
The cache does currently need three tensors to store; I am also thinking this part over, and better solutions are welcome.
Using external padding plus internal slicing, it should be possible to merge the caches of layers running at different frame rates into one.
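A sketch of that idea (a hypothetical helper, not this PR's code; it assumes an attention cache laid out as (batch, head, time, dim) and that half-rate layers should use every other cached frame):

```python
import torch

def layer_cache_view(att_cache: torch.Tensor, layer_idx: int,
                     reduce_idx: int, recover_idx: int) -> torch.Tensor:
    """Every layer shares one cache tensor sized for the full frame rate;
    half-rate layers read a strided view of it, so their effective window
    is automatically halved."""
    if reduce_idx <= layer_idx < recover_idx:
        return att_cache[:, :, ::2, :]
    return att_cache
```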
Development Record