xformers ViT-B ImageNet MAE + Deepnorm training instability #219
Comments
thanks for the issue, just saw that, looking into it!
@jramapuram can you elaborate on your config; do you have Triton installed, for instance? Could you share a print(model) here, to be sure of which parts are actually instantiated? After a quick look, it seems there could be a part where the gradients are not handled at the same precision as torch; at least that's my #1 hypothesis.
many thanks for the detailed issue and code snippets, this is perfect
if you're using Triton, could you test out installing a recent dev package?
also @jramapuram, could you confirm that this is with torch AMP (fp16)?
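For context, "torch AMP (fp16)" refers to the standard autocast + GradScaler loop; a minimal sketch, with `model`, `optimizer`, and `loader` as placeholders:

```python
import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()

for images, labels in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():  # ops run in fp16/fp32 per autocast rules
        logits = model(images.cuda())
        loss = F.cross_entropy(logits, labels.cuda())
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales grads first, skips the step on inf/NaN
    scaler.update()
```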
cc @dianaml0, @fmassa: is that something you've seen? I remember @xwhan saw that at some point, but I thought it had been fixed. I just did a quick check in the Triton code, and we keep the data type as fp32 in the softmax and layernorm cases when AMP is activated, which should give a precision similar to PyTorch (layernorm is a bit below). It looks like a vanishing gradient problem, and the parts here are very standard (MLP and scaled_dot_product attention); I'm wondering whether it could be somewhere else in the code, or whether the timm ViT adds some parameter-less normalization, for instance. I'm not seeing this on the CIFAR example that we host. edit: adding some more context and info
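To illustrate the precision point (a sketch of the idea only, not the actual Triton kernel): the reduction is done in fp32 even when autocast has cast the inputs to fp16, then the result is cast back:

```python
import torch

def softmax_fp32(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # upcast for the numerically sensitive reduction, return in the input dtype
    return torch.softmax(x.float(), dim=dim).to(x.dtype)
```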
@jramapuram the eps parameter for LayerNorm is not the same between timm and xformers (1e-6 vs. 1e-5); it's a long shot, but since your issue could be related to vanishing gradients, it could explain this. Fixing that.
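For reference, a quick check of the two defaults (PyTorch's `nn.LayerNorm`, which xformers follows, vs. the value timm's ViT passes):

```python
from torch import nn

ln_xformers_like = nn.LayerNorm(768)           # PyTorch default: eps=1e-5
ln_timm_like = nn.LayerNorm(768, eps=1e-6)     # timm ViT passes eps=1e-6
print(ln_xformers_like.eps, ln_timm_like.eps)  # 1e-05 1e-06
```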
Filling in details:
Instantiated model print to STDOUT: https://gist.github.com/jramapuram/d284e0f261d3fdb15c213dd929d272b9
I can actually repro the problem with the minimal microViT example (prior to the linked PRs); I just need to wait long enough. Testing right now with the changes from the linked PRs.
seems fine with the updated eps @jramapuram; let me know if it fixes your issue?
Training now; will update here :)

```python
import torch
from torch import nn


def update_ln_eps(module: nn.Module, new_eps: float):
    """Recurse through the module tree and update every LayerNorm eps with this value."""
    from xformers.triton.layer_norm import FusedLayerNorm

    if isinstance(module, torch.nn.modules.LayerNorm):
        module.eps = new_eps

    if isinstance(module, FusedLayerNorm):
        module.epsilon = new_eps

    for _, child in module.named_children():
        update_ln_eps(child, new_eps)
```
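(Applied before training as `update_ln_eps(model, 1e-6)`, `model` being the instantiated ViT, to match the timm value.)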
@blefaudeux: Unfortunately this has not seemed to fix it for me 😬. Not sure if the scaling from microViT to ViT-B ImageNet might be causing some issues that are not easily evident. With the LN fix using the function above, and with Triton commit d4c28fb (tried with and without Triton), the instability persists. For sanity I also tried swapping back to TIMM, and it is still working 😬
ouch, this is not good... the issue auto-closed, it seems, but I'm keeping it open; I'll try to dig a bit more.
@jramapuram, to try to pinpoint this a little better (and if you have time), could you try in an environment which does not have Triton? A few parts will then default to PyTorch; if you don't see the issue there, I'd know where to look (namely softmax and layernorm).
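A quick way to check whether Triton is visible in a given environment (a sketch; xformers falls back to its PyTorch implementations when the import fails):

```python
import importlib.util

print("triton available:", importlib.util.find_spec("triton") is not None)
```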
Else I can think of:
- the weights init
- the input projection
- pre- vs. post-norm
I can confirm that it does not happen on CIFAR and a smaller ViT, unfortunately; it would have been nice to have an easy repro. edit: adding more context
testing with pure pytorch layers right now, and I'm not seeing any difference so far, so that might not be a good explanation
the init is different indeed (see for instance the TIMM snippet below), while xformers mostly follows the PyTorch defaults
the projection seems to follow the same structure, an n × 3n matrix + bias; nothing different here
nope, pre-norm in both cases. In short, I don't see much difference (provided my home test with PyTorch vs. Triton parts is confirmed on your end @jramapuram) except for the weights init; since AMP training is notoriously a little finicky, maybe that could explain it? Not super intuitive to me, but I'm having a deeper look.
For reference, the TIMM weight init:

```python
from torch import nn
from timm.models.layers import trunc_normal_


def _init_vit_weights(module: nn.Module):
    """Transformer weight initialization from TIMM."""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)
```

@blefaudeux: are there any xformers layers that would bypass this init? Thanks for the great suggestions btw!
ahh, I didn't know about the init on your side, so this rules it out also!
No, I don't think so; although fused MLP uses a normal nn.Linear, it fuses the dropout/bias/activation (so the bias init would be missed). It does not seem like you're using FusedMLP, though, so that should not be the case.
No problem, this is a little perplexing to be honest, but we'll root it out!
hmm, turns out I was testing with rotary embeddings turned on, and they make a huge difference
thanks @jramapuram, it's very informative: so no issues with the Triton layers whatsoever, and the problem is in a pure PyTorch definition... :/
Still no joy on ef6de0f 😬. Here I only init just the
oh yes, on the current main branch nothing addressing this has landed yet. Could you try #303 by any chance? I can try to start something later today, but I'm a little bit underwater atm :/
No worries! Will give that a shot :) [feel better!] I added the reference pre-norm graphs above. Differences are basically:
oh wow, it's pretty clear indeed, thanks @jramapuram. #303 definitely fixes a small bug, but I doubt that it really explains this; I'll dive back into deepnorm. I may have a repro actually: with the recent metaformer + CIFAR10, deepnorm does not work either, but I thought that was because of the decidedly different model structure. I'll give it a second look; sorry for the delay.
hmm, I did spend some time on that and found nothing obviously wrong; it's really perplexing. I'll give IN (ImageNet) a shot. If you have the option, would it be possible to test this without AMP, just in case it's a matter of numerical accuracy (which would not be caught by the grad scaler if not NaN)?
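Since vanishing gradients are the running hypothesis, a small probe along these lines (a hypothetical helper, called after `backward()`) can surface the failure even when the GradScaler sees no inf/NaN:

```python
from torch import nn

def global_grad_norm(model: nn.Module) -> float:
    """L2 norm over all parameter gradients; under AMP, call scaler.unscale_(optimizer) first."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total += p.grad.detach().float().norm().item() ** 2
    return total ** 0.5
```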
Just in case @jramapuram, could you check which Triton dev build you're using?
Thanks for keeping this in mind @blefaudeux. Just checked; using the 20220430 dev build. Re the minGPT: I'm surprised there is a perf drop; does the test loss / negative log-likelihood follow the same trend?
20220430 was fine; the ones after that were broken, but fixed by triton-lang/triton@205a493, so it's back to being good at the moment! Re minGPT: I can check the other metrics. As mentioned in another thread, I think it may be due to the distribution being hardcoded right now for deepnorm; it's not very readable or hackable, and not a great design overall. I'd like to come up with something better and more explicit (for instance, a couple of possible inits as part of the xformers config, with deepnorm respecting that). It's always possible to init from the outside, but that's tied to parameter naming conventions (not super clear right now), and it kind of negates the point of supporting deepnorm to begin with, I think.
Unfortunately no joy @blefaudeux. I tried:
thanks a bunch @jramapuram! I have a draft PR getting ready which rewrites a lot of the input projections (something we discussed earlier), plus explicit handling of a couple of init methods (optional, users are still free to do as they please); I'm hoping that it solves this. To give an insight, I think that this setting is not well handled and could be the culprit (deepnorm assumes a different projection per Q/K/V, and the default here should probably be "true", I believe).
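To illustrate the setting in question (a sketch, not the actual xformers code): with a single fused input projection, Q/K/V share one weight matrix, so an init scheme like deepnorm, which wants to scale the value projection differently from query/key, has nothing separate to scale:

```python
from torch import nn

dim = 768

# fused: one n x 3n matrix, a single init covers all three streams
qkv_fused = nn.Linear(dim, 3 * dim)

# separate: a deepnorm-style init can apply its gain to V only
q_proj, k_proj, v_proj = (nn.Linear(dim, dim) for _ in range(3))
beta = 0.87  # illustrative value; the deepnorm gain actually depends on depth
nn.init.xavier_normal_(q_proj.weight, gain=1.0)
nn.init.xavier_normal_(k_proj.weight, gain=1.0)
nn.init.xavier_normal_(v_proj.weight, gain=beta)
```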
I think that #312 is getting there @jramapuram; it's a lot cleaner to my eyes. Something I've seen, related to your curves above: it's not just deepnorm, the post-normalization path does not play well with ViT. GPT is fine with this normalization path; I don't know if that's a known fact, I would need to check the literature. Since deepnorm is a subset of the post-normalization code path, it makes a little more sense, or at least it's not alone.
ok, beyond #312 which cleans things up, it looks like (going by timm here) layernorm requires a specific treatment for ViT + post-norm: the weight is initialized to a very small value (vs. 1 typically). Since in our case post-norm & deepnorm (same residual codepath) both fail with ViT but work well with GPT, that could explain why. I'll give that a shot.
I've not forgotten this @jramapuram; turns out that for vision / post-norm, Swin v2 already solved this (related to the message above), see their paper. The initial weights need to be scaled way down; I'll try to implement this in xformers when I get the time.
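If I understand the Swin v2 trick correctly, it amounts to something like this at init time (the exact scale is an assumption, and `residual_norms` is a hypothetical list of the post-norm LayerNorms on the residual path):

```python
from torch import nn

def scale_down_post_norms(residual_norms: list[nn.LayerNorm], init_scale: float = 1e-4):
    # start the residual branches near-identity by shrinking the LN affine weight
    for ln in residual_norms:
        nn.init.constant_(ln.weight, init_scale)
        nn.init.zeros_(ln.bias)
```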
🐛 Bug
I'm trying to create a 1:1 config that can train a stable ViT-B with the MAE config (from appendix A.2).
Maybe I'm missing something (highly plausible), but when I use xformers instead of timm, it creates an unstable training scenario [over numerous trials] with exactly the same hyper-parameters (batch_size=4096 + CutMix + MixUp + label smoothing + AdamW[0.9, 0.95], lr=1e-4 [with the scaling rule, of course], lr warmup + cosine decay, skipping weight decay for bias/CLS/pos_embed, etc.).
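For reference, a sketch of the optimizer side of that recipe (the weight decay value and the parameter-name matching are assumptions; warmup + cosine schedule omitted):

```python
import torch
from torch import nn

def build_optimizer(model: nn.Module, base_lr: float = 1e-4, batch_size: int = 4096):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # skip weight decay for biases, norms, the CLS token and positional embeddings
        if param.ndim <= 1 or "cls_token" in name or "pos_embed" in name:
            no_decay.append(param)
        else:
            decay.append(param)

    lr = base_lr * batch_size / 256  # linear lr scaling rule
    return torch.optim.AdamW(
        [{"params": decay, "weight_decay": 0.05},  # 0.05 is an assumption (MAE-style)
         {"params": no_decay, "weight_decay": 0.0}],
        lr=lr, betas=(0.9, 0.95),
    )
```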
xformers ViT-B Config
xformers ViT-B
Command
To Reproduce
Steps to reproduce the behavior: