Conversation
python/mxnet/gluon/loss.py (outdated)
@@ -409,6 +409,8 @@ class CTCLoss(Loss):
        length respectively.
    weight : float or None
        Global scalar weight for loss.
    blank_label : {'first', 'last'}, default 'last'
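For context, a minimal sketch of how gluon's CTCLoss is typically invoked; the shapes, label values, and explicit length arguments below are illustrative assumptions, not taken from this PR:

from mxnet import nd, gluon

# Illustrative shapes: 2 sequences, 20 time steps, 4 symbols plus 1 blank = 5 classes.
batch_size, seq_len, num_classes = 2, 20, 5
pred = nd.random.uniform(shape=(batch_size, seq_len, num_classes))  # layout 'NTC'
# Label values avoid index 0 and index num_classes-1, so the example stays valid
# whichever position the blank label occupies.
label = nd.array([[1, 2, 3, 1], [2, 2, 1, 1]])

ctc = gluon.loss.CTCLoss(layout='NTC', label_layout='NT')
# Passing explicit lengths sidesteps any particular padding convention.
loss = ctc(pred, label,
           nd.array([seq_len, seq_len]),  # prediction lengths
           nd.array([4, 3]))              # label lengths
print(loss)  # one CTC loss value per batch element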
It was intentional not to expose this option in gluon.
That means I need to revert this commit and resend a pull request?
You can simply add a commit that removes the change to blank_label.
Is the speed issue with large vocabulary sizes solved in this op?
@chinakook it looks good to me
This reverts commit aab11f7.
@Jerryzcn could you give this a try?
Will do.
Thanks @HawkAaron. It seems removing the max subtraction negatively affects convergence. Do you observe a similar result? In the original Baidu CTC, there is a section that does the max subtraction. I suspect the broadcast is too slow here. Maybe we should write a function for this.
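For reference, the numerical concern here is the standard max-subtraction (log-sum-exp) trick: exponentiating large activations directly overflows, while subtracting the per-row max first keeps the denominator finite. A small NumPy sketch of the idea, illustrative only and not the operator's code:

import numpy as np

def log_denoms_naive(logits):
    # Direct exponentiation overflows for large activations.
    return np.log(np.exp(logits).sum(axis=-1))

def log_denoms_stable(logits):
    # Subtract the per-row max before exponentiating; the max is added
    # back afterwards, so the value is mathematically unchanged.
    m = logits.max(axis=-1, keepdims=True)
    return np.log(np.exp(logits - m).sum(axis=-1)) + m.squeeze(-1)

logits = np.array([[1000.0, 999.0, 998.0]])
print(log_denoms_naive(logits))   # [inf] -- overflow (with a RuntimeWarning)
print(log_denoms_stable(logits))  # [~1000.41], finite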
Subtracting the max is missing.
denoms_handle = reduce_with_axis<red::sum, false>(
    F<mxnet::op::mshadow_op::exp>(
        log_probs_handle -
        broadcast<0>(reduce_with_axis<red::maximum, false>(log_probs_handle, 1),
The max is necessary here.
The max is reduced here: https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/ctc_include/detail/gpu_ctc.h#L398
and log_probs -= denoms happens here: https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/ctc_include/detail/gpu_ctc.h#L409
That means the broadcast line (https://github.com/apache/incubator-mxnet/blob/master/src/operator/contrib/ctc_include/detail/gpu_ctc.h#L417) should be zero.
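In other words, once the denominator already has the max folded back in, subtracting the max again when forming the log-softmax would double-count it. A hedged NumPy sketch of that identity, not the kernel's actual code:

import numpy as np

logits = np.random.randn(4, 6)  # (time steps, alphabet) -- illustrative shapes

# Denominator computed with max subtraction, then the max folded back in.
m = logits.max(axis=1, keepdims=True)
denoms = np.log(np.exp(logits - m).sum(axis=1, keepdims=True)) + m

# log_probs -= denoms already yields the log-softmax; no further max
# subtraction is needed, which is why the extra broadcast term should be zero.
log_softmax = logits - denoms

ref = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
print(np.allclose(log_softmax, ref))  # True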
I see. I did not look at the other parts of the code, thanks! Found a bug on my end.
@HawkAaron thanks for the fix, and @Jerryzcn thanks for the review.
* fix ctc_loss GPU bug
* add blank_label parameter for CTCLoss
* Revert "add blank_label parameter for CTCLoss"
  This reverts commit aab11f7.
2. add blank label for gluon CTCLoss (edit @szha: crossed out second item)