Bugfix for MNASNet #1224

Merged Sep 23, 2019 (44 commits)

Commits
50bfbe6
Add initial mnasnet impl
1e100 Apr 2, 2019
e1c5506
Remove all type hints, comply with PyTorch overall style
1e100 Apr 2, 2019
0d77acc
Expose models
1e100 Apr 2, 2019
c41aaab
Remove avgpool from features() and add separately
1e100 Apr 2, 2019
d6115f9
Merge upstream
1e100 Apr 5, 2019
568bd50
Fix python3-only stuff, replace subclasses with functions
1e100 Apr 13, 2019
5617b8e
fix __all__
1e100 Apr 13, 2019
ba0ad4d
Fix typo
1e100 Apr 13, 2019
bd4836b
Remove conditional dropout
1e100 Apr 14, 2019
5ac43bd
Merge branch 'master' of github.com:1e100/vision
1e100 Apr 14, 2019
102ba55
Make dropout functional
1e100 Apr 15, 2019
9c8b827
Addressing @fmassa's feedback, round 1
1e100 Apr 16, 2019
2872b1f
Replaced adaptive avgpool with mean on H and W to prevent collapsing …
1e100 Apr 16, 2019
05b387b
Partially address feedback
1e100 May 3, 2019
2d39797
YAPF
1e100 May 3, 2019
8b5f7b9
Removed redundant class vars
1e100 May 3, 2019
8de71fe
Merge master
1e100 May 6, 2019
40471ac
Update urls to releases
1e100 May 6, 2019
b1d54ec
Add information to models.rst
1e100 May 6, 2019
ec717d0
Replace init with kaiming_normal_ in fan-out mode
1e100 May 11, 2019
8b2dba9
Use load_state_dict_from_url
1e100 May 12, 2019
06177ee
Merge master
1e100 May 21, 2019
c34df87
Merge master again
1e100 May 21, 2019
8b538ae
Merge branch 'master' of https://github.com/pytorch/vision
1e100 Jun 27, 2019
7be7478
Fix depth scaling on first 2 layers
1e100 Jun 27, 2019
e996c36
Restore initialization
1e100 Jun 30, 2019
1fc9c76
Match reference implementation initialization for dense layer
1e100 Jul 4, 2019
e5164e3
Meant to use Kaiming
1e100 Jul 4, 2019
f5c9a17
Remove spurious relu
1e100 Jul 12, 2019
1b7808e
Point to the newest 0.5 checkpoint
1e100 Aug 10, 2019
96eb194
Latest pretrained checkpoint
1e100 Aug 10, 2019
f50e776
Merge branch 'master' of https://github.com/pytorch/vision
1e100 Aug 10, 2019
0626a21
Restore 1.0 checkpoint
1e100 Aug 10, 2019
af9679d
YAPF
1e100 Aug 10, 2019
c611d0d
Implement backwards compat as suggested by Soumith
1e100 Sep 7, 2019
ed89aac
Update checkpoint URL
1e100 Sep 7, 2019
36fa9fa
Move warnings up
1e100 Sep 7, 2019
3ceed68
Record a couple more function parameters
1e100 Sep 7, 2019
b9e60c2
Update comment
1e100 Sep 7, 2019
2c8ccbc
Set the correct version such that if the BC-patched model is saved, i…
1e100 Sep 7, 2019
9165032
Merge branch 'master' of github.com:1e100/vision
1e100 Sep 7, 2019
061dade
Set a member var, not class var
1e100 Sep 14, 2019
d0a43c4
Update mnasnet.py
1e100 Sep 20, 2019
00ddb9d
Update the path to weights
1e100 Sep 20, 2019
89 changes: 70 additions & 19 deletions torchvision/models/mnasnet.py
@@ -1,4 +1,5 @@
import math
import warnings

import torch
import torch.nn as nn
@@ -8,7 +9,7 @@

_MODEL_URLS = {
"mnasnet0_5":
"https://download.pytorch.org/models/mnasnet0.5_top1_67.592-7c6cb539b9.pth",
"https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
"mnasnet0_75": None,
"mnasnet1_0":
"https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
@@ -74,14 +75,16 @@ def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
return new_val if new_val >= round_up_bias * val else new_val + divisor


def _scale_depths(depths, alpha):
def _get_depths(alpha):
""" Scales tensor depths as in reference MobileNet code, prefers rouding up
rather than down. """
depths = [32, 16, 24, 40, 80, 96, 192, 320]
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]


class MNASNet(torch.nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf.
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
>>> model = MNASNet(1000, 1.0)
>>> x = torch.rand(1, 3, 224, 224)
>>> y = model(x)
@@ -90,30 +93,36 @@ class MNASNet(torch.nn.Module):
>>> y.nelement()
1000
"""
# Version 2 adds depth scaling in the initial stages of the network.
_version = 2

def __init__(self, alpha, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__()
depths = _scale_depths([24, 40, 80, 96, 192, 320], alpha)
assert alpha > 0.0
self.alpha = alpha
self.num_classes = num_classes
depths = _get_depths(alpha)
layers = [
# First layer: regular conv.
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
# Depthwise separable, no skip.
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1,
groups=depths[0], bias=False),
nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM),
# MNASNet blocks: stacks of inverted residuals.
_stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM),
_stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM),
_stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM),
_stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM),
_stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM),
_stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM),
_stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM),
_stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM),
_stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM),
# Final mapping to classifier input.
nn.Conv2d(depths[5], 1280, 1, padding=0, stride=1, bias=False),
nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
]
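The effect of the depth-scaling fix above, sketched with a couple of concrete values (assuming _round_to_multiple_of rounds to the nearest multiple of 8 with the 0.9 round-up bias shown earlier; these values are an illustration, not output quoted from the PR):

_get_depths(1.0)  # [32, 16, 24, 40, 80, 96, 192, 320]
_get_depths(0.5)  # [16, 8, 16, 24, 40, 48, 96, 160]
# Before this change, the stem was hard-coded to 32 and 16 channels regardless
# of alpha; it now uses depths[0] and depths[1], so every stage scales.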
@@ -139,16 +148,58 @@ def _initialize_weights(self):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0.01)
nn.init.kaiming_uniform_(m.weight, mode="fan_out",
                         nonlinearity="sigmoid")
nn.init.zeros_(m.bias)

Review comment on the kaiming_uniform_ line:

Member: This changes the initialization scheme. Does it yield better performance?

Contributor Author (1e100): It may have very slightly improved the top-1 of the MNASNet B1 0.5 model I trained for this PR, but I'm not sure. The purpose of the change is that initialization is now identical to the reference TensorFlow code (which also uses a variance-scaling initializer, a.k.a. Kaiming uniform). It is certainly not worse than before.
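As a standalone sketch of the initialization being discussed (illustration only; 1280 and 1000 are the classifier dimensions used elsewhere in this file):

import torch.nn as nn

fc = nn.Linear(1280, 1000)
nn.init.kaiming_uniform_(fc.weight, mode="fan_out", nonlinearity="sigmoid")
nn.init.zeros_(fc.bias)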

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get("version", None)
assert version in [1, 2]

if version == 1 and not self.alpha == 1.0:
# In the initial version of the model (v1), stem was fixed-size.
# All other layer configurations were the same. This will patch
# the model so that it's identical to v1. Model with alpha 1.0 is
# unaffected.
depths = _get_depths(self.alpha)
v1_stem = [
nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32,
bias=False),
nn.BatchNorm2d(32, momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True),
nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False),
nn.BatchNorm2d(16, momentum=_BN_MOMENTUM),
_stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM),
]
for idx, layer in enumerate(v1_stem):
self.layers[idx] = layer

# The model is now identical to v1, and must be saved as such.
self._version = 1
warnings.warn(
"A new version of MNASNet model has been implemented. "
"Your checkpoint was saved using the previous version. "
"This checkpoint will load and work as before, but "
"you may want to upgrade by training a newer model or "
"transfer learning from an updated ImageNet checkpoint.",
UserWarning)

super(MNASNet, self)._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)
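For context on the mechanism used above: nn.Module.state_dict() records each module's _version in the state dict's _metadata, and load_state_dict() passes it back as local_metadata["version"]. A minimal sketch with a toy module (an illustration, not this PR's code):

import torch.nn as nn

class Toy(nn.Module):
    _version = 2  # bumped when the parameter layout changes, as MNASNet does here

    def __init__(self):
        super(Toy, self).__init__()
        self.fc = nn.Linear(4, 4)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get("version", None)
        if version is None or version < 2:
            # An old-layout checkpoint would be patched here before deferring
            # to the default loader, as MNASNet does for its stem.
            pass
        super(Toy, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys,
            unexpected_keys, error_msgs)

model = Toy()
state = model.state_dict()
print(state._metadata[""]["version"])  # 2
model.load_state_dict(state)  # hands the saved version back to _load_from_state_dict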


def _load_pretrained(model_name, model, progress):
if model_name not in _MODEL_URLS or _MODEL_URLS[model_name] is None:
raise ValueError(
"No checkpoint is available for model type {}".format(model_name))
checkpoint_url = _MODEL_URLS[model_name]
model.load_state_dict(load_state_dict_from_url(checkpoint_url, progress=progress))
model.load_state_dict(
load_state_dict_from_url(checkpoint_url, progress=progress))


def mnasnet0_5(pretrained=False, progress=True, **kwargs):
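A short usage note on the mnasnet0_5 / mnasnet1_0 constructors (hedged example; only variants with a URL in _MODEL_URLS have pretrained weights):

from torchvision import models

model = models.mnasnet1_0(pretrained=True)  # downloads the top-1 73.512 checkpoint
# models.mnasnet0_75(pretrained=True) raises ValueError, since
# _MODEL_URLS["mnasnet0_75"] is None and no checkpoint is available.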