RNNT model raises KeyError: 'th_accuracy' when the executor runs cv #2336

Closed
DaobinZhu opened this issue Feb 2, 2024 · 3 comments

Comments

@DaobinZhu
Contributor

After training finishes, the cv step fails with the error below, but the code looks correct to me.
2024-02-02 11:31:51,367 DEBUG TRAIN Batch 0/7500 loss 129.710220 loss_att 106.460495 loss_ctc 154.101700 loss_rnnt 131.107971 lr 0.00030000 grad_norm 27.183880 rank 1
2024-02-02 11:31:51,367 DEBUG TRAIN Batch 0/7500 loss 149.184647 loss_att 124.107864 loss_ctc 157.846741 loss_rnnt 153.045059 lr 0.00030000 grad_norm 27.183880 rank 2
2024-02-02 11:31:51,367 DEBUG TRAIN Batch 0/7500 loss 140.741501 loss_att 112.667061 loss_ctc 158.834442 loss_rnnt 143.944000 lr 0.00030000 grad_norm 27.183880 rank 3
2024-02-02 11:31:51,368 DEBUG TRAIN Batch 0/7500 loss 129.564987 loss_att 106.934181 loss_ctc 157.959274 loss_rnnt 130.305237 lr 0.00030000 grad_norm 27.183880 rank 0
Traceback (most recent call last):
  File "wenet/bin/train.py", line 175, in <module>
    main()
  File "/home/lsj/.conda/envs/rnnt/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "wenet/bin/train.py", line 147, in main
    loss_dict = executor.cv(model, cv_data_loader, configs)
  File "/home/lsj/zdb/biye/wenet/wenet/utils/executor.py", line 123, in cv
    ) if _dict['th_accuracy'] is not None else 0.0)
KeyError: 'th_accuracy'
Traceback (most recent call last):
  File "wenet/bin/train.py", line 175, in <module>
    main()
  File "/home/lsj/.conda/envs/rnnt/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "wenet/bin/train.py", line 147, in main
    loss_dict = executor.cv(model, cv_data_loader, configs)
  File "/home/lsj/zdb/biye/wenet/wenet/utils/executor.py", line 123, in cv
    ) if _dict['th_accuracy'] is not None else 0.0)
KeyError: 'th_accuracy'
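
The failure mode in the traceback can be reproduced in isolation. The sketch below is not WeNet code: `loss_dict`, `read_accuracy_strict`, and `read_accuracy_lenient` are hypothetical names, and the values are made up. It only illustrates that indexing a missing key raises `KeyError` before the `is not None` check can run, whereas `dict.get` falls back cleanly.

```python
# Hypothetical stand-in for the dict returned by the model during cv.
# The key names mirror the training log above; 'th_accuracy' is absent.
loss_dict = {"loss": 129.7, "loss_att": 106.5, "loss_ctc": 154.1, "loss_rnnt": 131.1}


def read_accuracy_strict(d):
    # Same shape as the failing expression: the key is indexed first,
    # so a missing key raises KeyError before the None check is evaluated.
    return d["th_accuracy"] if d["th_accuracy"] is not None else 0.0


def read_accuracy_lenient(d):
    # dict.get returns None for a missing key, so a missing key and an
    # explicit None both fall back to 0.0.
    acc = d.get("th_accuracy")
    return acc if acc is not None else 0.0


try:
    read_accuracy_strict(loss_dict)
    strict_error = None
except KeyError as exc:
    strict_error = exc.args[0]  # name of the missing key

lenient_value = read_accuracy_lenient(loss_dict)
```

A lenient read like this would only hide the missing key on the consumer side; whether the producer should supply the key instead is a separate design choice.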

@xingchensong
Member

You can add th_accuracy to the dict here:
https://github.com/wenet-e2e/wenet/blob/main/wenet/transducer/transducer.py#L144-L148

Once it is fixed, feel free to open a PR.
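
As a rough illustration of the suggestion above, the sketch below builds an output dict that always carries a `th_accuracy` entry, so the expression from the traceback no longer raises. It is an assumption-laden sketch, not the contents of `transducer.py`: `build_transducer_outputs`, `cv_accuracy`, the `acc_att` parameter, and the plain-float values are hypothetical, and the other key names are taken from the training log.

```python
from typing import Dict, Optional


def build_transducer_outputs(
    loss: float,
    loss_att: float,
    loss_ctc: float,
    loss_rnnt: float,
    acc_att: Optional[float] = None,
) -> Dict[str, Optional[float]]:
    """Hypothetical helper: return the loss dict with 'th_accuracy' always present."""
    return {
        "loss": loss,
        "loss_att": loss_att,
        "loss_ctc": loss_ctc,
        "loss_rnnt": loss_rnnt,
        # May be None when no attention-decoder accuracy is available;
        # the consumer already handles None by falling back to 0.0.
        "th_accuracy": acc_att,
    }


def cv_accuracy(outputs: Dict[str, Optional[float]]) -> float:
    # Mirrors the expression shown in the traceback.
    return outputs["th_accuracy"] if outputs["th_accuracy"] is not None else 0.0


outputs = build_transducer_outputs(129.71, 106.46, 154.10, 131.11)
accuracy = cv_accuracy(outputs)
```

Keeping the key present with a `None` value leaves the existing `is not None` check on the executor side meaningful.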

@DaobinZhu
Contributor Author

> You can add th_accuracy to the dict here: https://github.com/wenet-e2e/wenet/blob/main/wenet/transducer/transducer.py#L144-L148
>
> Once it is fixed, feel free to open a PR.

Got it, thanks a lot!

@DaobinZhu
Contributor Author

I made a silly mistake: I had been reading asr_model.py instead of transducer.py. I have submitted a PR: #2337
