[Numpy] MobileBERT SQuAD training cannot reproduce the previous results #1322
Let's try to bisect where the problem first occurred. @zheyuye could you share which commit of GluonNLP and which version of MXNet you used to produce the above result? |
@sxjscience Can you share your rerun log as well? Also, how long does a single run take? |
It takes 4 hours on a g4-12dn. The rerun log: https://gluonnlp-numpy-data.s3-us-west-2.amazonaws.com/debug/finetune_squad2.0_mobilebert_20200827.log You may check the |
If I'm not mistaken, the latest GluonNLP and MXNet 0810 were used |
Judging from the log, the reproduction run seems to converge too slowly. How about checking the other pre-trained models? ELECTRA-small won't take too long. |
The others work fine. |
Why do the two logs have basically the same grad norm for the last few logging points but very different losses? |
I have tried MXNet 0810 and 0826, but neither of them can reproduce the result. |
Zheyu's log says the run happened on 8/20, so maybe it was an earlier GluonNLP commit that was working. |
I think it may also be the random seed. We can see that the gradient norm is super large for the early iterations. I'm actually considering how to add proper gradient scaling to MobileBERT and recheck the conversion script. @zheyuye Would you know which version of GluonNLP you were using when producing this log? |
Conversion script is here: https://github.com/dmlc/gluon-nlp/blob/master/scripts/conversion_toolkits/convert_mobilebert.sh. We will need to re-verify the conversion to see if there are any issues. |
The PR and branch related to this issue, https://github.com/ZheyuYe/gluon-nlp/commits/batch, show that it was based on d8b68c6. In fact, random seeds can be another potential factor, and in general I would use 10 or 28 as the seed instead of the default value 100. |
I think the gradients for the early iterations are unreasonably large, and we may consider investigating that. |
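As a reference for the gradient-scaling idea mentioned above, here is a minimal hand-rolled sketch of global-norm gradient clipping (the function name and the `max_norm` value are illustrative, not the training script's actual settings):

```python
import mxnet as mx
mx.npx.set_np()

def clip_grad_global_norm(grads, max_norm):
    """Rescale `grads` in place so their global L2 norm is at most `max_norm`."""
    total_norm = float(mx.np.sqrt(sum((g ** 2).sum() for g in grads)))
    scale = min(1.0, max_norm / (total_norm + 1e-12))
    if scale < 1.0:
        for g in grads:
            g *= scale
    return total_norm
```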
The problem is that no version can reproduce the result. |
I added a markdown table in the first post for summarizing our observations. Feel free to directly edit. |
@szha @szhengac I think we should check the conversion and the training script of MobileBERT again to solve this issue. Also, the problem is that MXNet is not reproducible even when we specify the random seed. We may change the seed-related defaults as recommended in apache/mxnet#18987. |
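As a side note on seeding, a minimal sketch of a helper that seeds every RNG a training script typically touches (the helper name and the seed value are ours; apache/mxnet#18987 discusses the actual default changes):

```python
import random

import numpy as np
import mxnet as mx

def set_seed(seed: int) -> None:
    """Best-effort reproducibility: seed the Python, NumPy, and MXNet RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    mx.random.seed(seed)  # seeds MXNet's global RNG across devices

set_seed(28)  # e.g. one of the seeds suggested earlier in the thread
```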
In fact, I can confirm that the forward check introduced in gluon-nlp/scripts/conversion_toolkits/convert_mobilebert.py (lines 304 to 313 in 66e5e05) passes.
|
@szha @szhengac @zheyuye I noticed that the gradient of the code at gluon-nlp/src/gluonnlp/models/mobilebert.py (lines 690 to 699 in 66e5e05) is computed incorrectly: the backward pass of `mx.np.pad` is buggy, as the minimal example below shows. |
Minimal Reproducible Example

MXNet implementation:

```python
import mxnet as mx
import numpy as np

mx.npx.set_np()
ctx = mx.gpu()

a = mx.np.ones((3, 3, 3), ctx=ctx)
mult = np.random.normal(0, 1, (3, 3, 3))
a.attach_grad()
with mx.autograd.record():
    # Shift along axis 1 via slice + pad, then multiply and reduce.
    b = mx.np.pad(a[:, 1:], ((0, 0), (0, 1), (0, 0))) * mx.np.array(mult, ctx=ctx)
    b = b.sum()
b.backward()
print(a.grad)
```

Output:
JAX implementation:

```python
from jax import grad
import jax.numpy as jnp
import numpy as np

mult = np.random.normal(0, 1, (3, 3, 3))
a = jnp.ones((3, 3, 3))

def f(x):
    # Same slice + pad + multiply + reduce as the MXNet version.
    b = jnp.pad(x[:, 1:], ((0, 0), (0, 1), (0, 0))) * jnp.array(mult)
    return b.sum()

print(grad(f)(a))
```

Output:
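For reference, the expected gradient can be derived by hand: `b[:, j] = x[:, j+1] * mult[:, j]` for the unpadded positions, and the padded slice contributes nothing, so the gradient must be zero at `[:, 0]` and equal `mult[:, :-1]` elsewhere. A plain-NumPy sketch of that reference (the helper name is ours):

```python
import numpy as np

def expected_grad(mult):
    # grad[:, 0]  == 0            -- nothing flows into the dropped first slice
    # grad[:, 1:] == mult[:, :-1] -- each kept slice is scaled by its multiplier
    expected = np.zeros_like(mult)
    expected[:, 1:] = mult[:, :-1]
    return expected
```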
|
Discussed offline with @cassinixu: fixing the pad operator on the MXNet side requires some time. Meanwhile, a simple fix is to rewrite the affected code in gluon-nlp/src/gluonnlp/models/mobilebert.py (lines 690 to 699 in 66e5e05).
@zheyuye Would you try this approach? |
Basically, we can use |
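For illustration, one way to express the same shift without `mx.np.pad` (a sketch under our own naming, not necessarily the replacement that was adopted):

```python
import mxnet as mx
mx.npx.set_np()

def shift_without_pad(x):
    # Equivalent to mx.np.pad(x[:, 1:], ((0, 0), (0, 1), (0, 0))):
    # drop the first slice along axis 1 and append a zero slice at the end,
    # avoiding the operator whose backward pass is broken.
    zeros = mx.np.zeros_like(x[:, :1])
    return mx.np.concatenate([x[:, 1:], zeros], axis=1)
```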
Confirmed that apache/mxnet#19044 fixed the bug. Closing this issue. |
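For completeness, a small regression check one could run against a fixed MXNet build (a sketch reusing the minimal example above; the tolerance is arbitrary):

```python
import mxnet as mx
import numpy as np

mx.npx.set_np()

a = mx.np.ones((3, 3, 3))
a.attach_grad()
mult = np.random.normal(0, 1, (3, 3, 3))
with mx.autograd.record():
    b = (mx.np.pad(a[:, 1:], ((0, 0), (0, 1), (0, 0))) * mx.np.array(mult)).sum()
b.backward()

# The correct gradient: zero at [:, 0], mult[:, :-1] elsewhere (see derivation above).
expected = np.zeros((3, 3, 3))
expected[:, 1:] = mult[:, :-1]
assert np.allclose(a.grad.asnumpy(), expected, atol=1e-6), "np.pad backward is still broken"
```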
@szhengac Would you submit a PR to update the SQuAD v1 + SQuAD v2 results of MobileBERT? |
@zheyuye (FYI @szhengac)
I ran the MobileBERT training on SQuAD again and the log is significantly different from the log reported in https://gluon-nlp-log.s3.amazonaws.com/squad_training_log/fintune_google_uncased_mobilebert_squad_2.0/finetune_squad2.0.log
To reproduce, just install the master version of GluonNLP and try the command in https://github.com/dmlc/gluon-nlp/blob/master/scripts/question_answering/commands/run_squad2_mobilebert.sh.