Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contrastive learning for unified models #2100

Merged
merged 5 commits into from
Nov 9, 2023
Merged

Conversation

YYTtyy
Copy link
Contributor

@YYTtyy YYTtyy commented Nov 2, 2023

Hi,
This PR is the implementation of the INTERSPEECH 2023 paper Enhancing the Unified Streaming and Non-streaming Model with Contrastive Learning
Arxiv:https://arxiv.org/abs/2306.00755

Details:
add joint training & contrastive loss for unified models (in ctl_model/asr_model_ctl.py)
add pure full-context mode forward (in ctl_model/encoder.py)
only return chunk size 1~25 for training (in ctl_model/mask.py)

The results on the AISHELL-1 dataset from the literature are as follows:
image

In addition, we conducted experiments on the in-house corpus, which contains 25000 hours of Mandarin speech data. The results show that our method makes consistent improvements on the larger dataset. This table shows the results on the test set.

Models Chunk=-1 Chunk=16
U2 model 17% 18.8%
Ours 16.7% 18.2%

@xingchensong
Copy link
Member

感谢 ! 几点小建议

  1. wenet/ctl_model/mask.py (左)似乎没有修改必要?看diff和 wenet/utils/mask.py (右) 其实是基本一样的(少了一个if分支)
    6ed102f9-9f28-41fb-b724-43a4bf5889a7

  2. wenet/ctl_model/encoder.py (左)可以直接继承 wenet/transformer/encoder.py (右), 因为看代码相当于在BaseEncoder这个类上追加了一个成员函数forward_full且没有其他删除操作。继承的方式可以参考 https://github.com/wenet-e2e/wenet/blob/main/wenet/paraformer/layers.py#L207-L298
    dc72a9e7-746e-4021-8799-850f78c2de7b

  3. 可以rebase一下代码,最近training pipeline做了修改,train.py的大部分代码都挪到了train_utils.py,这个PR的相应修改可以酌情挪到train_utils, refactor(deepspeed): Refine traning code #2055

@YYTtyy
Copy link
Contributor Author

YYTtyy commented Nov 3, 2023

好的没问题!感谢建议~

@xingchensong
Copy link
Member

great work ! pr is quite clear,期待后续推送文章!

@xingchensong xingchensong merged commit a114e39 into wenet-e2e:main Nov 9, 2023
@kobenaxie
Copy link
Contributor

有训练的loss曲线可以参考吗

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants