Bugfix for MNASNet #1224
Conversation
…the batch dimension
Sure, we can make it backward compatible. Of the two proposed solutions I like solution 1 better (fewer moving parts), but I'll implement whatever we decide here. Maybe I should take this opportunity and add squeeze-and-excitation as well. The authors use it in some of the larger variants of the model. Accuracy will go up a bit if I do that, and then we won't have to version it again in the future. Is there a release schedule, BTW? I wanted this fix to be ready before the release (models unfortunately take forever to train), but noticed the release had been cut a few days ago, so the PR didn't make it.
Hi @1e100 , I just got back from holidays, sorry for the delay in replying.
I'm not yet clear on the best solution. I feel that this is something that needs to be carefully considered, because model versioning is going to be a big topic.
I feel that this should be sent in a separate PR.
We will be cutting a new release of torchvision in the next 2-3 weeks, with minor fixes and improvements. I'm also tagging @ailzhang for handling BC-breaking changes within hub, and @cpuhrsch @vincentqb and @zhangguanheng66 for torchaudio and torchtext model versioning in the future.
I had to handle the BC breaking for nn.MultiheadAttention. To extend the capability of the module, I added four extra attributes to the module. We used the … For the second option, we have to instruct users to use …
OK, had some time today to look into this, aiming to get it done over the weekend. Basically, it seems that the following simple logic would satisfy the backward compat requirements:
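Roughly, and as a sketch only (not the exact code that landed in the PR): nn.Module records each submodule's _version in the state-dict metadata when a checkpoint is saved, so the loader can branch on it when an old checkpoint is loaded:

```python
import warnings
import torch.nn as nn

class MNASNetSketch(nn.Module):
    # nn.Module writes _version into the checkpoint metadata on save,
    # so checkpoints produced by the fixed model are tagged as version 2.
    _version = 2

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version != 2:
            # Old (v1) checkpoint: rebuild the first layers in the original
            # layout so the stored weights still fit, mark the instance as v1
            # so re-saving keeps it loadable, and tell the user about it.
            warnings.warn("Loading a v1 MNASNet checkpoint.", UserWarning)
            # ... reconstruct the v1 stem and set self._version = 1 here ...
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
```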
Seems like a pretty straightforward fix to me.
OK, @fmassa, here's the first cut of the requested changes. CI seems to be failing on something CUDA-related. Let me know if this is what you had in mind!
…t could be reloaded with BC patching again
This is looking very good, I like it!
I have a few more comments, let me know what you think.
Also, I am thinking that we might still need to add an option somewhere (maybe in the mnasnet_0_5 function) that initially raises a warning if the user doesn't pass an argument, saying that the default behavior will change in a new version, so that we don't break BC right away for the users?
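Something like the following could express that idea; this is only a sketch of the suggestion, and the `layout` keyword, its default, and the message wording are made up for illustration rather than taken from the PR:

```python
import warnings

def mnasnet_0_5(pretrained=False, layout=None, **kwargs):
    # Hypothetical builder sketch: keep the old behavior as the default for now,
    # but warn that the default will flip in a future release so users are not
    # silently broken when the new layer layout becomes the default.
    if layout is None:
        warnings.warn(
            "The default MNASNet layer layout will change in a future release; "
            "pass layout=1 to keep the current behavior and silence this warning.",
            FutureWarning,
        )
        layout = 1
    # ... build and return the model for the requested layout ...
```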
torchvision/models/mnasnet.py
```python
self.layers[idx] = layer

# The model is now identical to v1, and must be saved as such.
MNASNet._version = 1
```
This modifies all instances of MNASNet, not just the one being called. This could have some unexpected effects; maybe you meant to do something like self._version instead?
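A quick standalone illustration of the class-attribute vs. instance-attribute difference (not code from the PR):

```python
class Demo:
    _version = 2  # class attribute, shared by every instance

a, b = Demo(), Demo()

Demo._version = 1  # assigning via the class changes what all instances see
assert a._version == 1 and b._version == 1

a._version = 2     # assigning via an instance only shadows it on that instance
assert a._version == 2 and b._version == 1
```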
D'oh! You're right. Changed, and verified it works with this code:
```python
#!/usr/bin/env python3
import torch
import torchvision

# NOTE: v1 checkpoint
ckpt = torch.load("mnasnet0.5_top1_67.592-7c6cb539b9.pth")
m = torchvision.models.MNASNet(0.5)
m.load_state_dict(ckpt)
print("Loaded old")
torch.save(m.state_dict(), "resaved.pth")
print("Re-saved")
ckpt = torch.load("resaved.pth")
m = torchvision.models.MNASNet(0.5)
m.load_state_dict(ckpt)
print("Re-loaded")
```
```diff
@@ -139,16 +149,58 @@ def _initialize_weights(self):
         nn.init.ones_(m.weight)
         nn.init.zeros_(m.bias)
     elif isinstance(m, nn.Linear):
-        nn.init.normal_(m.weight, 0.01)
+        nn.init.kaiming_uniform_(m.weight, mode="fan_out",
```
This changes the initialization scheme, does this yield better performance?
It may have very slightly improved the top1 on MNASNet b1 0.5 that I trained for this PR, but I'm not sure. The purpose of the change is that initialization is now identical to the reference TensorFlow code (which also uses a variance scaling initializer aka Kaiming uniform). Certainly not worse than before.
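For reference, the PyTorch side of that comparison is a fan-out Kaiming-uniform call on the classifier's Linear layer; the layer sizes below are only an example, and any further arguments used in the PR (the diff line above is truncated) are not reproduced here:

```python
import torch.nn as nn

# Example only: fan-out Kaiming-uniform init on a classifier layer.
fc = nn.Linear(1280, 1000)
nn.init.kaiming_uniform_(fc.weight, mode="fan_out")
nn.init.zeros_(fc.bias)
```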
@fmassa I've addressed your feedback, PTAL
Sorry for the delay in replying, I've made a few more comments
Remove unused member var as per review.
I've addressed the feedback, PTAL
Thanks a lot!
I'll upload the weights and update the PR
@1e100 I couldn't push the updated path without a force push on your master branch. Can you update the link in the PR to … and let me know?
@fmassa done!
Thanks a lot!
The original implementation I submitted contained a bug which affects all MNASNet variants other than 1.0. The bug is that the first few layers also need to be scaled by the width multiplier, along with all the rest. This PR fixes the issue and brings the implementation fully in sync with Google's TPU reference code. I have compared the ONNX dump of this model against TFLite's hosted model and ensured that all layer configurations line up exactly.
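In concrete terms, the fix means the stem widths are scaled by the width multiplier and rounded the same way as every other stage, rather than being hard-coded. A simplified sketch of that scaling (the helper names and rounding bias here are illustrative, not copied verbatim from the patch):

```python
def _round_to_multiple_of(val, divisor=8, round_up_bias=0.9):
    # Round a channel count to a multiple of `divisor`, preferring to round up:
    # if plain rounding would shrink the value by more than ~10%, bump it up.
    rounded = max(divisor, int(val + divisor / 2) // divisor * divisor)
    return rounded if rounded >= round_up_bias * val else rounded + divisor

def _scaled_depths(alpha):
    # Every stage width, including the first few layers, is scaled by alpha.
    # Before this fix the first widths were constant regardless of alpha.
    base_depths = [32, 16, 24, 40, 80, 96, 192, 320]
    return [_round_to_multiple_of(d * alpha) for d in base_depths]

print(_scaled_depths(0.5))  # [16, 8, 16, 24, 40, 48, 96, 160]
print(_scaled_depths(1.0))  # [32, 16, 24, 40, 80, 96, 192, 320]
```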
Because only the MNASNet 0.5 checkpoint was affected, I have also trained a slightly better checkpoint for it. I was unable to train it to the same accuracy with torchvision's reference training code (and not for lack of trying), and had to use label smoothing and EMA to get this result. The final checkpoint is derived from the EMA weights.
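The PR itself doesn't include the training script, but as a rough illustration of those two ingredients in current PyTorch (the decay and smoothing values below are arbitrary, not the ones used for the posted checkpoint):

```python
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel
from torchvision.models import MNASNet

model = MNASNet(0.5)

# Label-smoothed cross entropy (the label_smoothing argument requires a
# reasonably recent PyTorch release).
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Exponential moving average of the weights; the checkpoint that finally gets
# published would be taken from ema_model.module rather than from model.
decay = 0.9999
ema_model = AveragedModel(
    model, avg_fn=lambda avg, cur, num: decay * avg + (1.0 - decay) * cur
)

# In the training loop, after each optimizer step:
#     ema_model.update_parameters(model)
```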
Even so, the accuracy is a bit lower than Google's result (67.83 for this model vs. 68.03 for the hosted TFLite model). The posted number for the TF TPU implementation is 68.9, but that model uses SE on some of its layers, which my implementation does not.